Confusion about PyTorch checkpoint for online BootsTAPIR #125

Open
weirenorweiren opened this issue Oct 30, 2024 · 0 comments

I noticed that the PyTorch checkpoints for online BootsTAPIR referenced in torch_causal_tapir_demo.ipynb and in the README are not the same. The former is https://storage.googleapis.com/dm-tapnet/causal_bootstapir_checkpoint.pt, while the latter is https://storage.googleapis.com/dm-tapnet/bootstap/causal_bootstapir_checkpoint.pt, which has an extra bootstap directory in the path. The checkpoint from the demo loads fine, but the checkpoint from the README fails in load_state_dict with a RuntimeError reporting size mismatches. The errors for pyramid_level=1 and pyramid_level=0 are listed below. Could you please clarify the following questions?

  1. Which checkpoint is the right one to use? Currently only the one from the demo works, but I am not sure whether it is the correct one.
  2. If the one in the README is the right one, how can I make it work?
  3. Do you have a PyTorch checkpoint for online TAPIR?

pyramid_level=1

wget -P F:\Wei\tapnet\tapnet\checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/causal_bootstapir_checkpoint.pt
model = tapir_model.TAPIR(pyramid_level=1, use_casual_conv=True)
model.load_state_dict(torch.load("tapnet/checkpoints/causal_bootstapir_checkpoint.pt"))

Traceback (most recent call last):
  File "F:\Wei\tapnet\tapnet\live_demo_thorlab_camera_torch.py", line 17, in <module>
    model.load_state_dict(torch.load("tapnet/checkpoints/causal_bootstapir_checkpoint.pt"))
  File "C:\Users\NOCB\anaconda3\envs\tapnet_torch\Lib\site-packages\torch\nn\modules\module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TAPIR:
        size mismatch for extra_convs.blocks.0.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.1.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.2.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.3.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.4.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).

pyramid_level=0

wget -P F:\Wei\tapnet\tapnet\checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/causal_bootstapir_checkpoint.pt
model = tapir_model.TAPIR(pyramid_level=0, use_casual_conv=True)
model.load_state_dict(torch.load("tapnet/checkpoints/causal_bootstapir_checkpoint.pt"))

Traceback (most recent call last):
  File "F:\Wei\tapnet\tapnet\live_demo_thorlab_camera_torch.py", line 17, in <module>
    model.load_state_dict(torch.load("tapnet/checkpoints/causal_bootstapir_checkpoint.pt"))
  File "C:\Users\NOCB\anaconda3\envs\tapnet_torch\Lib\site-packages\torch\nn\modules\module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TAPIR:
        size mismatch for torch_pips_mixer.linear.weight: copying a param with shape torch.Size([512, 535]) from checkpoint, the shape in current model is torch.Size([512, 486]).
        size mismatch for extra_convs.blocks.0.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.1.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.2.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.3.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
        size mismatch for extra_convs.blocks.4.conv.weight: copying a param with shape torch.Size([1024, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 256, 3, 3]).
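
One way to narrow down which model configuration a checkpoint expects is to compare parameter shapes before calling load_state_dict. The sketch below is a hedged illustration: `find_shape_mismatches` is a hypothetical helper, and it works on plain name-to-shape dictionaries so that it runs without PyTorch. The comments show how such dictionaries might be built from a checkpoint and a model, assuming the checkpoint file holds a plain state dict.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Compare two {parameter name: shape} mappings.

    Returns (mismatched, missing, unexpected):
      mismatched: names present in both mappings with different shapes
      missing:    names the model has but the checkpoint lacks
      unexpected: names the checkpoint has but the model lacks
    """
    mismatched = {
        name: (tuple(checkpoint_shapes[name]), tuple(model_shapes[name]))
        for name in checkpoint_shapes.keys() & model_shapes.keys()
        if tuple(checkpoint_shapes[name]) != tuple(model_shapes[name])
    }
    missing = sorted(model_shapes.keys() - checkpoint_shapes.keys())
    unexpected = sorted(checkpoint_shapes.keys() - model_shapes.keys())
    return mismatched, missing, unexpected


# With PyTorch, the two mappings could be built like this (not run here):
#   state = torch.load(path, map_location="cpu")
#   checkpoint_shapes = {k: tuple(v.shape) for k, v in state.items()}
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}

# Shapes copied from the pyramid_level=0 traceback above.
checkpoint_shapes = {
    "torch_pips_mixer.linear.weight": (512, 535),
    "extra_convs.blocks.0.conv.weight": (1024, 768, 3, 3),
}
model_shapes = {
    "torch_pips_mixer.linear.weight": (512, 486),
    "extra_convs.blocks.0.conv.weight": (1024, 256, 3, 3),
}

mismatched, missing, unexpected = find_shape_mismatches(
    checkpoint_shapes, model_shapes
)
for name, (ckpt, model) in sorted(mismatched.items()):
    print(f"{name}: checkpoint {ckpt} vs model {model}")
```

Running the same comparison against models built with different constructor arguments would show which combination, if any, matches the README checkpoint.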