Hi, thanks for your excellent work. I don't know why others haven't brought this up, but when I tried to run the small Mamba2 demo in the README:
```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=32,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
I got:

```
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x16 and 32x385)
```
Any idea what happened? Thanks
Perhaps `dim` should be equal to `d_model`.
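A minimal sketch of why the shapes clash, assuming the model's first step is a linear projection that flattens the `(batch, length, dim)` input into `(batch * length, dim)` rows. The helper `check_linear_input` is hypothetical and only mimics the inner-dimension check; it is not `mamba_ssm` or PyTorch code, and the `385` output width is simply copied from the reported error.

```python
from math import prod


def check_linear_input(x_shape, in_features, out_features):
    """Mimic the shape check a linear layer performs on a (..., features) input.

    The leading dimensions are flattened into rows, so a (batch, length, dim)
    input becomes a (batch * length, dim) matrix that must be multiplied by an
    (in_features, out_features) weight matrix.
    """
    *leading, features = x_shape
    rows = prod(leading)
    if features != in_features:
        # The inner dimensions disagree, so the matrix product is undefined
        raise RuntimeError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{features} and {in_features}x{out_features})"
        )
    # The output keeps the leading dimensions and replaces the feature size
    return (*leading, out_features)


batch, length, dim = 2, 64, 16
d_model = 32
try:
    check_linear_input((batch, length, dim), d_model, 385)
except RuntimeError as err:
    print(err)
```

With `dim = 16` and `d_model = 32` this reproduces the message from the issue (`128x16 and 32x385`, since 2 * 64 = 128 rows); passing an input whose last dimension equals `d_model` makes the check succeed.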
Change the `headdim` argument to 16 and `d_model` to 16.
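A minimal sketch of why `headdim` also has to change once `d_model` drops to 16. It assumes (not verified here against the `mamba_ssm` source) that Mamba2 computes an inner width of `expand * d_model`, splits it into heads of size `headdim`, and defaults `headdim` to 64. `mamba2_heads` is a hypothetical helper for illustration only.

```python
def mamba2_heads(d_model, expand=2, headdim=64):
    """Hypothetical helper: derive the inner width and head count.

    Assumes d_inner = expand * d_model and that d_inner must split
    evenly into heads of width headdim (64 assumed as the default).
    """
    d_inner = expand * d_model
    if d_inner % headdim != 0:
        raise ValueError(
            f"d_inner={d_inner} is not divisible by headdim={headdim}"
        )
    return d_inner, d_inner // headdim


print(mamba2_heads(16, expand=2, headdim=16))  # (32, 2)
try:
    mamba2_heads(16)  # 32 cannot be split into 64-wide heads
except ValueError as err:
    print(err)
```

Under these assumptions, `dim = 16` with `d_model=16` and `headdim=16` gives an inner width of 32 split into 2 heads, while keeping the assumed default `headdim=64` would fail, which is why both arguments are changed together.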