Update I-JEPA example to use timm #1614
Hey @guarin, I'd like to contribute to this feature request. I went through the example files and, using the similar `mae.py` examples as a reference, jotted down some changes. Am I heading in the right direction? If so, I can open a PR for this feature request for you to review. Here is what I have so far:

```python
import copy

import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn
from torch.nn import functional as F

from lightly.data.collate import IJEPAMaskCollator
from lightly.models import utils

# from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor
from lightly.models.modules import IJEPAPredictorTIMM, MaskedVisionTransformerTIMM
from lightly.transforms.ijepa_transform import IJEPATransform


class IJEPA(nn.Module):
    def __init__(self, vit_encoder, vit_predictor, momentum_scheduler):
        super().__init__()
        # self.encoder = IJEPABackbone.from_vit(vit_encoder)
        self.encoder = MaskedVisionTransformerTIMM(vit=vit_encoder)
        # self.predictor = IJEPAPredictor.from_vit_encoder(
        #     vit_predictor.encoder,
        #     (vit_predictor.image_size // vit_predictor.patch_size) ** 2,
        # )
        # timm's VisionTransformer does not keep depth, num_heads, qkv_bias,
        # mlp_ratio or the dropout rates as attributes, so depth and num_heads
        # are read from the transformer blocks and the rest use timm's ViT-B defaults.
        self.predictor = IJEPAPredictorTIMM(
            num_patches=vit_predictor.patch_embed.num_patches,
            depth=len(vit_predictor.blocks),
            mlp_dim=vit_predictor.embed_dim,
            predictor_embed_dim=384,  # as in the official VisionTransformerPredictor
            num_heads=vit_predictor.blocks[0].attn.num_heads,  # official uses 12
            qkv_bias=True,
            mlp_ratio=4.0,
            drop_path_rate=0.0,
            proj_drop_rate=0.0,
            attn_drop_rate=0.0,
        )
        self.target_encoder = copy.deepcopy(self.encoder)
        self.momentum_scheduler = momentum_scheduler

    def forward_target(self, imgs, masks_enc, masks_pred):
        with torch.no_grad():
            # encode() returns one feature per token; forward() pools them
            h = self.target_encoder.encode(images=imgs)
            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
            B = len(h)
            # -- create targets (masked regions of h)
            h = utils.apply_masks(h, masks_pred)
            h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc))
            return h

    def forward_context(self, imgs, masks_enc, masks_pred):
        # masks_enc holds the indices of the context patches to keep
        z = self.encoder.encode(images=imgs, idx_keep=masks_enc)
        z = self.predictor(z, masks_enc, masks_pred)
        return z

    def forward(self, imgs, masks_enc, masks_pred):
        z = self.forward_context(imgs, masks_enc, masks_pred)
        h = self.forward_target(imgs, masks_enc, masks_pred)
        return z, h

    def update_target_encoder(self):
        # EMA update: target = m * target + (1 - m) * online
        with torch.no_grad():
            m = next(self.momentum_scheduler)
            for param_q, param_k in zip(
                self.encoder.parameters(), self.target_encoder.parameters()
            ):
                param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)


collator = IJEPAMaskCollator(
    input_size=(224, 224),
    patch_size=32,
)

transform = IJEPATransform()

# we ignore object detection annotations by setting target_transform to return 0
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")


def target_transform(t):
    return 0


dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
data_loader = torch.utils.data.DataLoader(
    dataset, collate_fn=collator, batch_size=10, persistent_workers=False
)

ema = (0.996, 1.0)
ipe_scale = 1.0
ipe = len(data_loader)
num_epochs = 10
# linear momentum schedule from ema[0] to ema[1] over all training iterations
momentum_scheduler = (
    ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
    for i in range(int(ipe * num_epochs * ipe_scale) + 1)
)

# vit_for_predictor = torchvision.models.vit_b_32(pretrained=False)  # Remove torchvision dependency
vit_for_predictor = vit_base_patch32_224()
# vit_for_embedder = torchvision.models.vit_b_32(pretrained=False)  # Remove torchvision dependency
vit_for_embedder = vit_base_patch32_224()
model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler)
# More lines below this
```

Also, why is there a discrepancy between `IJEPAPredictorTIMM` and the official `VisionTransformerPredictor` (https://github.com/facebookresearch/ijepa/blob/main/src/models/vision_transformer.py#L445) with the usage of
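The `momentum_scheduler` generator and `update_target_encoder` in the draft implement a linear EMA schedule. The target-encoder momentum ramps from `ema[0]` to `ema[1]` over `ipe * num_epochs * ipe_scale` iterations, and each target weight becomes `m * target + (1 - m) * online`. The pure-Python sketch below reproduces that arithmetic on plain floats, so the values can be checked without torch. `linear_momentum_schedule` and `ema_update` are hypothetical helper names and are not part of lightly.

```python
from typing import Iterator, List


def linear_momentum_schedule(
    start: float,
    end: float,
    steps_per_epoch: int,
    num_epochs: int,
    scale: float = 1.0,
) -> Iterator[float]:
    """Hypothetical helper mirroring the draft's generator expression.

    Yields int(steps_per_epoch * num_epochs * scale) + 1 values that ramp
    linearly from `start` to `end` (inclusive).
    """
    total_steps = steps_per_epoch * num_epochs * scale
    for i in range(int(total_steps) + 1):
        yield start + i * (end - start) / total_steps


def ema_update(target: List[float], online: List[float], m: float) -> List[float]:
    """Return m * target + (1 - m) * online, element by element."""
    return [m * t + (1.0 - m) * o for t, o in zip(target, online)]


# 5 iterations per epoch for 2 epochs -> 10 steps -> 11 momentum values
schedule = list(linear_momentum_schedule(0.996, 1.0, steps_per_epoch=5, num_epochs=2))
print(len(schedule), schedule[0], schedule[-1])

# one EMA step with the first (smallest) momentum value
print(ema_update([0.0, 1.0], [1.0, 1.0], schedule[0]))
```

The draft uses a one-shot generator. If `update_target_encoder` is called more than `int(ipe * num_epochs * ipe_scale) + 1` times, it raises `StopIteration`, so the schedule length has to match the number of optimizer steps.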
I-JEPA is compatible with timm since #1612.

We should update the I-JEPA examples to use timm instead of torchvision:

- `examples/pytorch/ijepa.py`
- `examples/pytorch_lightning/ijepa.py`
- `examples/pytorch_lightning_distributed/ijepa.py`
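Part of the migration is mapping ViT hyperparameters between the two libraries. The torchvision-based predictor computed the patch count as `(image_size // patch_size) ** 2`, whereas a timm version can read `vit.patch_embed.num_patches`. For ViT-B/32 at 224×224, both give 7 × 7 = 49 patches. The stdlib-only sketch below records the ViT-B/32 settings under torchvision's `vit_b_32` argument names and the corresponding timm `VisionTransformer` argument names. `VIT_B32` and `num_patches` are hypothetical names used only for illustration.

```python
from typing import Dict, Tuple

# Hypothetical table: torchvision constructor argument -> (timm argument, ViT-B/32 value)
VIT_B32: Dict[str, Tuple[str, int]] = {
    "image_size": ("img_size", 224),
    "patch_size": ("patch_size", 32),
    "hidden_dim": ("embed_dim", 768),
    "num_layers": ("depth", 12),
    "num_heads": ("num_heads", 12),
}
# torchvision specifies the MLP width directly; timm uses a ratio to embed_dim
TORCHVISION_MLP_DIM = 3072


def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


# Translate the table into timm-style keyword arguments
timm_kwargs = {timm_name: value for timm_name, value in VIT_B32.values()}
timm_kwargs["mlp_ratio"] = TORCHVISION_MLP_DIM / timm_kwargs["embed_dim"]

print(num_patches(224, 32))  # 49
print(timm_kwargs)
```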