Update I-JEPA example to use timm #1614

Open · 3 tasks
guarin opened this issue Jul 29, 2024 · 1 comment

guarin commented Jul 29, 2024

I-JEPA has been compatible with timm since #1612.

We should update the I-JEPA examples to use timm instead of torchvision (a rough sketch of the swap follows the task list):

  • Update examples/pytorch/ijepa.py
  • Update examples/pytorch_lightning/ijepa.py
  • Update examples/pytorch_lightning_distributed/ijepa.py
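
A minimal sketch of the swap, assuming timm's vit_base_patch32_224 entry point and lightly's MaskedVisionTransformerTIMM wrapper (the exact wiring still needs to be confirmed against the updated modules):

import timm

from lightly.models.modules import MaskedVisionTransformerTIMM

# The current examples build the backbone with torchvision:
# vit = torchvision.models.vit_b_32(pretrained=False)
# With timm, the equivalent backbone can be created directly:
vit = timm.create_model("vit_base_patch32_224", pretrained=False)
encoder = MaskedVisionTransformerTIMM(vit=vit)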

Shrinidhibhat87 commented Oct 18, 2024

Hey @guarin, I would like to contribute to this feature request. I went through the example files and, using the similar mae.py examples as a reference, jotted down some changes. I wanted to ask whether I am heading in the right direction.

If so, I can create a PR based on the feature request and you guys can take a look.

Here is what I have for now:

# More headers
# from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor
from lightly.models.modules import IJEPAPredictorTIMM, MaskedVisionTransformerTIMM
from lightly.transforms.ijepa_transform import IJEPATransform


class IJEPA(nn.Module):
    def __init__(self, vit_encoder, vit_predictor, momentum_scheduler):
        super().__init__()
        # Previous torchvision-based version:
        # self.encoder = IJEPABackbone.from_vit(vit_encoder)
        self.encoder = MaskedVisionTransformerTIMM(vit=vit_encoder)
        # Previous torchvision-based version:
        # self.predictor = IJEPAPredictor.from_vit_encoder(
        #     vit_predictor.encoder,
        #     (vit_predictor.image_size // vit_predictor.patch_size) ** 2,
        # )
        self.predictor = IJEPAPredictorTIMM(
            num_patches=vit_predictor.patch_embed.num_patches,
            depth=vit_predictor.depth,
            mlp_dim=vit_predictor.embed_dim,
            predictor_embed_dim=384,  # default in the official VisionTransformerPredictor class
            num_heads=vit_predictor.num_heads,  # official implementation uses 12
            qkv_bias=vit_predictor.qkv_bias,
            mlp_ratio=vit_predictor.mlp_ratio,
            drop_path_rate=vit_predictor.drop_path_rate,
            proj_drop_rate=vit_predictor.proj_drop_rate,
            attn_drop_rate=vit_predictor.attn_drop_rate,
        )
        self.target_encoder = copy.deepcopy(self.encoder)
        self.momentum_scheduler = momentum_scheduler

    def forward_target(self, imgs, masks_enc, masks_pred):
        with torch.no_grad():
            h = self.target_encoder(images=imgs)
            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
            B = len(h)
            # -- create targets (masked regions of h)
            h = utils.apply_masks(h, masks_pred)
            h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc))
            return h

    def forward_context(self, imgs, masks_enc, masks_pred):
        # Encode the visible (context) patches and predict features for the masked target patches.
        z = self.encoder(imgs, masks_enc)
        z = self.predictor(z, masks_enc, masks_pred)
        return z

    def forward(self, imgs, masks_enc, masks_pred):
        z = self.forward_context(imgs, masks_enc, masks_pred)
        h = self.forward_target(imgs, masks_enc, masks_pred)
        return z, h

    def update_target_encoder(self):
        # Momentum (EMA) update of the target encoder parameters.
        with torch.no_grad():
            m = next(self.momentum_scheduler)
            for param_q, param_k in zip(
                self.encoder.parameters(), self.target_encoder.parameters()
            ):
                param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)


collator = IJEPAMaskCollator(
    input_size=(224, 224),
    patch_size=32,
)

transform = IJEPATransform()

# we ignore object detection annotations by setting target_transform to return 0
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")


def target_transform(t):
    return 0


dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
data_loader = torch.utils.data.DataLoader(
    dataset, collate_fn=collator, batch_size=10, persistent_workers=False
)

# Linearly increase the EMA momentum from ema[0] to ema[1] over the full training run.
ema = (0.996, 1.0)
ipe_scale = 1.0
ipe = len(data_loader)  # iterations per epoch
num_epochs = 10
momentum_scheduler = (
    ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
    for i in range(int(ipe * num_epochs * ipe_scale) + 1)
)

# vit_for_predictor = torchvision.models.vit_b_32(pretrained=False)  # removed torchvision dependency
vit_for_predictor = vit_base_patch32_224()
# vit_for_embedder = torchvision.models.vit_b_32(pretrained=False)  # removed torchvision dependency
vit_for_embedder = vit_base_patch32_224()
model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler)

# More lines below this
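
For context, the lines below that marker keep the training loop from the existing example, roughly along these lines (the loss and optimizer here are placeholders rather than the exact code):

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

criterion = nn.SmoothL1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

for epoch in range(num_epochs):
    total_loss = 0.0
    for udata, masks_enc, masks_pred in data_loader:
        imgs = udata[0].to(device)
        masks_enc = [mask.to(device) for mask in masks_enc]
        masks_pred = [mask.to(device) for mask in masks_pred]

        z, h = model(imgs, masks_enc, masks_pred)
        loss = criterion(z, h)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model.update_target_encoder()  # EMA update of the target encoder

        total_loss += loss.item()
    avg_loss = total_loss / len(data_loader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")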

Also, why is there a discrepancy between IJEPAPredictorTIMM and the official VisionTransformerPredictor (https://github.com/facebookresearch/ijepa/blob/main/src/models/vision_transformer.py#L445) in how mlp_dim and embed_dim are used? Correct me if I am wrong, but I believe this is inconsistent with the original implementation.
