The model requires `num_beams`, although it is not needed in the example #105
Hey @LEv145, thank you for bringing that up! The parameters can be found in the …
Thanks, it works!

```
Load checkpoint from /mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt
Model Loaded
Traceback (most recent call last):
  File "/mnt/store/tests/test_rugpt3xl.py", line 29, in <module>
    main()
  File "/mnt/store/tests/test_rugpt3xl.py", line 19, in main
    result = gpt.generate(
  File "/opt/ru-gpts/src/xl_wrapper.py", line 244, in generate
    return list(map(self.tokenizer.decode, res.tolist()))
AttributeError: 'NoneType' object has no attribute 'tolist'
```

Code:

```python
import os
import sys

sys.path.append("/opt/ru-gpts/")

os.environ["USE_DEEPSPEED"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "5000"

from src.xl_wrapper import RuGPT3XL


def main():
    gpt = RuGPT3XL.from_pretrained(
        "sberbank-ai/rugpt3xl",
        weights_path="/mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt",
        seq_len=512,
    )
    result = gpt.generate(
        "Кто был президентом США в 2020? ",
        max_length=50,
        num_beams=5,
        early_stopping=True,
    )
    print(result)


if __name__ == "__main__":
    main()
```
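Reading the traceback: loading succeeds ("Model Loaded"), and the failure is on line 244 of `src/xl_wrapper.py`, where `res` is `None`, so `res.tolist()` raises `AttributeError`. The thread does not show why `res` is `None` for this `num_beams=5` call. As a minimal sketch only, here is a hypothetical `decode_batch` helper (not part of ru-gpts) that mirrors that decode line but reports a missing result explicitly; `fake_decode` is an illustrative stand-in for `tokenizer.decode`.

```python
from typing import Callable, List, Sequence


def decode_batch(res, decode: Callable[[Sequence[int]], str]) -> List[str]:
    """Hypothetical helper: decode a batch of generated token ids.

    Mirrors `list(map(self.tokenizer.decode, res.tolist()))`, but raises a
    clear error instead of an AttributeError when no output was produced.
    """
    if res is None:
        raise RuntimeError("generate() produced no token ids (res is None)")
    # torch tensors expose .tolist(); plain nested lists are used as-is.
    rows = res.tolist() if hasattr(res, "tolist") else res
    return [decode(row) for row in rows]


# Stand-in for tokenizer.decode, for illustration only.
def fake_decode(ids: Sequence[int]) -> str:
    return " ".join(str(i) for i in ids)


print(decode_batch([[1, 2], [3]], fake_decode))  # ['1 2', '3']
```

This only makes the symptom easier to read; it does not fix whatever makes the beam-search path return `None`.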
I have the same problem while generating text with the model.
Environment:

- Ubuntu 20.04
- pytorch==1.11.0a0+17540c5c
- NVIDIA CUDA 11.6.0
- TensorRT 8.2.3
- transformers==4.26.1
- apex: NVIDIA/apex@0c8400a (or qywu/apex@798a36c with patch_amp_state.py)
- deepspeed==0.8.0
- triton==1.0.0
- timm==0.3.2
Code:
Error:
I don't know what `num_beams` does or how to make it work, but I would be happy to help.

Pip freeze
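On what `num_beams` controls: in the Hugging Face `transformers` `generate()` API it is the number of candidate sequences kept at each decoding step of beam search, and `num_beams=1` without sampling is plain greedy decoding. Whether `xl_wrapper.py` follows exactly the same semantics is an assumption here. The sketch below is a generic, simplified beam search over a hypothetical 3-token toy model (`toy_next_log_probs` is invented for illustration); it is not the ruGPT3XL or `transformers` implementation and omits details such as length penalties.

```python
import math
from typing import Callable, List, Sequence, Tuple

# A "model" maps the tokens generated so far to next-token log-probabilities.
NextTokenLogProbs = Callable[[Sequence[int]], List[float]]


def beam_search(
    next_log_probs: NextTokenLogProbs,
    prefix: Sequence[int],
    max_length: int,
    num_beams: int,
    eos_id: int,
) -> List[int]:
    """Simplified beam search; num_beams=1 reduces to greedy decoding."""
    # Each hypothesis is (cumulative log-probability, token list).
    beams: List[Tuple[float, List[int]]] = [(0.0, list(prefix))]
    finished: List[Tuple[float, List[int]]] = []

    while beams:
        # Expand every live hypothesis by every possible next token.
        candidates = [
            (score + lp, seq + [token])
            for score, seq in beams
            for token, lp in enumerate(next_log_probs(seq))
        ]
        # Keep only the num_beams highest-scoring expansions.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:num_beams]:
            if seq[-1] == eos_id or len(seq) >= max_length:
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        # Stop once num_beams hypotheses are complete (loosely like early_stopping=True).
        if len(finished) >= num_beams:
            break

    return max(finished, key=lambda c: c[0])[1]


def toy_next_log_probs(seq: Sequence[int]) -> List[float]:
    """Hypothetical 3-token model: 0 = "A", 1 = "B", 2 = end of sequence."""
    if not seq:
        probs = [0.55, 0.44, 0.01]
    elif seq[-1] == 0:
        probs = [0.34, 0.33, 0.33]
    else:
        probs = [0.05, 0.05, 0.90]
    return [math.log(p) for p in probs]


print(beam_search(toy_next_log_probs, [], max_length=3, num_beams=1, eos_id=2))  # [0, 0, 0]
print(beam_search(toy_next_log_probs, [], max_length=3, num_beams=2, eos_id=2))  # [1, 2]
```

With one beam, the greedy choice of "A" leads to `[0, 0, 0]` (probability about 0.064); with two beams, "B" stays alive and the search finds `[1, 2]` (probability 0.396). This is why a larger `num_beams` can yield better sequences at the cost of more computation.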