
Training with 96x96 input #100

Open
rezendegabriel opened this issue Jan 10, 2023 · 1 comment

Comments


rezendegabriel commented Jan 10, 2023

Hello,

I am trying to adapt the code to run on 96x96 inputs, specifically the BigGAN.py script, by adding an arch[96] entry to the G_arch and D_arch functions. However, I get the following error at the line D_input = torch.cat([img for img in [G_z, x] if img is not None], 0) in G_D:

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 96 and 64 in dimension 2 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:71

What I have been able to determine so far is that the G_z and x tensors have different spatial dimensions, torch.Size([50, 3, 64, 64]) and torch.Size([50, 3, 96, 96]) respectively, so the concatenation cannot be done. I believe the G_z tensor should have the same shape as x. I would appreciate any help.
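For context, torch.cat requires every dimension except the concatenation dimension to match, which is exactly what the error reports (dimension 2: 64 vs. 96). A minimal pure-Python sketch of that shape rule (can_cat is a hypothetical helper, not part of the repo):

```python
def can_cat(shapes, dim=0):
    """torch.cat's shape rule: every dimension except `dim` must match
    across all inputs; sizes along `dim` itself may differ."""
    first = shapes[0]
    return all(
        s[d] == first[d]
        for s in shapes[1:]
        for d in range(len(first))
        if d != dim
    )

# G_z is 64x64 but the real batch x is 96x96 -> dimension 2 mismatches
print(can_cat([(50, 3, 64, 64), (50, 3, 96, 96)], dim=0))  # False

# Once the generator actually emits 96x96 images, the cat succeeds
print(can_cat([(50, 3, 96, 96), (50, 3, 96, 96)], dim=0))  # True
```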

arch[96] = {'in_channels':  [int(ch * item) for item in [16, 16, 8, 4]],
            'out_channels': [int(ch * item) for item in [16,  8, 4, 2]],
            'upsample': [True] * 4,
            'resolution': [12, 24, 48, 96],
            'attention': {2**i + 2**(i-1): (2**i + 2**(i-1) in [int(item) for item in attention.split('_')]) for i in range(3, 7)}}
.
.
.
arch[96] = {'in_channels':  [3] + [int(ch*item) for item in [1, 2, 4, 8]],
            'out_channels': [int(item * ch) for item in [1, 2, 4, 8, 16]],
            'downsample': [True] * 4 + [False],
            'resolution': [48, 24, 12, 6, 6],
            'attention': {2**i + 2**(i-1): 2**i + 2**(i-1) in [int(item) for item in attention.split('_')] for i in range(2, 7)}}
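As a sanity check on the dicts above, the attention comprehensions enumerate keys of the form 2**i + 2**(i-1), which should line up with each network's 'resolution' list. A small sketch (the attention string '48' below is a hypothetical setting, standing in for the value parsed from the repo's attention argument):

```python
# Keys generated by the generator's attention dict (range(3, 7))
g_keys = [2**i + 2**(i-1) for i in range(3, 7)]
print(g_keys)  # [12, 24, 48, 96] -- matches G's 'resolution' list

# Keys generated by the discriminator's attention dict (range(2, 7))
d_keys = [2**i + 2**(i-1) for i in range(2, 7)]
print(d_keys)  # [6, 12, 24, 48, 96]

# The dict maps each resolution to whether attention is enabled there,
# based on the underscore-separated attention string.
attention = '48'  # hypothetical: enable self-attention at 48x48 only
g_attn = {k: k in [int(s) for s in attention.split('_')] for k in g_keys}
print(g_attn)  # {12: False, 24: False, 48: True, 96: False}
```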
rezendegabriel (Author) commented

I solved the problem by changing the default of the bottom_width variable to 6 in the Generator class.
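That fix makes sense arithmetically: the generator reshapes its first linear layer's output to bottom_width x bottom_width, and each of the 4 upsampling blocks in arch[96] doubles the spatial size. A quick sketch of the resolution progression (output_resolution is a hypothetical helper for illustration):

```python
def output_resolution(bottom_width, num_upsamples):
    """Spatial size after `num_upsamples` 2x upsampling blocks,
    starting from a bottom_width x bottom_width feature map."""
    return bottom_width * 2 ** num_upsamples

# Default bottom_width=4 (the 64x64 setting): 4 -> 8 -> 16 -> 32 -> 64,
# which is why G_z came out 64x64 and mismatched the 96x96 real batch.
print(output_resolution(4, 4))  # 64

# bottom_width=6: 6 -> 12 -> 24 -> 48 -> 96, matching arch[96]['resolution']
print(output_resolution(6, 4))  # 96
```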
