Bert-based models crash #49

Open
lambdavi opened this issue Jan 7, 2024 · 4 comments
@lambdavi

lambdavi commented Jan 7, 2024

Hi there. Thanks for the great library!

I have one issue regarding the usage of BERT-based models. I trained several models by fine-tuning them on my custom dataset (RoBERTa, LUKE, DeBERTa, XLM-RoBERTa, etc.).

I tried to do the same with BERT using the same code, but I get an error (also when using your code from the getting started part of the documentation).

I am using a dataset with this format:
{"tokens": ["(7)", "On", "specific", "query", "by", "the", "Bench", "about", "an", "entry", "of", "Rs.", "1,31,37,500", "on", "deposit", "side", "of", "Hongkong", "Bank", "account", "of", "which", "a", "photo", "copy", "is", "appearing", "at", "p.", "40", "of", "assessee's", "paper", "book,", "learned", "authorised", "representative", "submitted", "that", "it", "was", "related", "to", "loan", "from", "broker,", "Rahul", "&", "Co.", "on", "the", "basis", "of", "his", "submission", "a", "necessary", "mark", "is", "put", "by", "us", "on", "that", "photo", "copy."], "ner_tags": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 21, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

And I load it with this script:

import json
from datasets import Dataset, DatasetDict
def load_legal_ner():
    ret = {}
    for split_name in ['TRAIN', 'DEV']:
        data = []
        with open(f"./data/NER_{split_name}/NER_{split_name}_ALL_OT.jsonl", 'r') as reader:
            for line in reader:
                data.append(json.loads(line))
        ret[split_name.lower()] = Dataset.from_list(data)
    return DatasetDict(ret)
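
The training itself follows the getting started example, roughly like this (simplified sketch; the encoder name and label list are placeholders for the actual values I use):

from span_marker import SpanMarkerModel, Trainer
from transformers import TrainingArguments

dataset = load_legal_ner()
labels = ["O", "ORG", "PER"]  # placeholder; the real list contains the dataset's entity types

model = SpanMarkerModel.from_pretrained(
    "bert-base-cased",  # e.g. roberta-base works fine, bert-base-cased crashes
    labels=labels,
    model_max_length=256,
    marker_max_length=128,
    entity_max_length=8,
)

args = TrainingArguments(
    output_dir="models/span_marker_bert",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["dev"],
)
trainer.train()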

For every other model, it works perfectly. But if I try to use a BERT-based model (e.g. bert-base-uncased, bert-base-cased, legal-bert, etc.), it crashes with different errors, always linked to the forward method (sometimes related to the normalization layer, sometimes to a matmul).

This is the traceback:

Cell In[8], line 28
     20 trainer = Trainer(
     21     model=model,
     22     args=args,
     23     train_dataset=dataset["train"],
     24     eval_dataset=dataset["dev"],
     25 )
     27 # Training is really simple using our Trainer!
---> 28 trainer.train()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1537, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1535         hf_hub_utils.enable_progress_bars()
   1536 else:
-> 1537     return inner_training_loop(
   1538         args=args,
   1539         resume_from_checkpoint=resume_from_checkpoint,
   1540         trial=trial,
   1541         ignore_keys_for_eval=ignore_keys_for_eval,
   1542     )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1854, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1851     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1853 with self.accelerator.accumulate(model):
-> 1854     tr_loss_step = self.training_step(model, inputs)
   1856 if (
   1857     args.logging_nan_inf_filter
   1858     and not is_torch_tpu_available()
   1859     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1860 ):
   1861     # if loss is nan or inf simply add the average of previous logged losses
   1862     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2723, in Trainer.training_step(self, model, inputs)
   2720     return loss_mb.reduce_mean().detach().to(self.args.device)
   2722 with self.compute_loss_context_manager():
-> 2723     loss = self.compute_loss(model, inputs)
   2725 if self.args.n_gpu > 1:
   2726     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2746, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2744 else:
   2745     labels = None
-> 2746 outputs = model(**inputs)
   2747 # Save past state if it exists
   2748 # TODO: this needs to be fixed and made cleaner later.
   2749 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/span_marker/modeling.py:153, in SpanMarkerModel.forward(self, input_ids, attention_mask, position_ids, start_marker_indices, num_marker_pairs, labels, num_words, document_ids, sentence_ids, **kwargs)
    136 """Forward call of the SpanMarkerModel.
    137 
    138 Args:
   (...)
    150     SpanMarkerOutput: The output dataclass.
    151 """
    152 token_type_ids = torch.zeros_like(input_ids)
--> 153 outputs = self.encoder(
    154     input_ids,
    155     attention_mask=attention_mask,
    156     token_type_ids=token_type_ids,
    157     position_ids=position_ids,
    158 )
    159 last_hidden_state = outputs[0]
    160 last_hidden_state = self.dropout(last_hidden_state)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:1013, in BertModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1004 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
   1006 embedding_output = self.embeddings(
   1007     input_ids=input_ids,
   1008     position_ids=position_ids,
   (...)
   1011     past_key_values_length=past_key_values_length,
   1012 )
-> 1013 encoder_outputs = self.encoder(
   1014     embedding_output,
   1015     attention_mask=extended_attention_mask,
   1016     head_mask=head_mask,
   1017     encoder_hidden_states=encoder_hidden_states,
   1018     encoder_attention_mask=encoder_extended_attention_mask,
   1019     past_key_values=past_key_values,
   1020     use_cache=use_cache,
   1021     output_attentions=output_attentions,
   1022     output_hidden_states=output_hidden_states,
   1023     return_dict=return_dict,
   1024 )
   1025 sequence_output = encoder_outputs[0]
   1026 pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:607, in BertEncoder.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    596     layer_outputs = self._gradient_checkpointing_func(
    597         layer_module.__call__,
    598         hidden_states,
   (...)
    604         output_attentions,
    605     )
    606 else:
--> 607     layer_outputs = layer_module(
    608         hidden_states,
    609         attention_mask,
    610         layer_head_mask,
    611         encoder_hidden_states,
    612         encoder_attention_mask,
    613         past_key_value,
    614         output_attentions,
    615     )
    617 hidden_states = layer_outputs[0]
    618 if use_cache:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:497, in BertLayer.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
    485 def forward(
    486     self,
    487     hidden_states: torch.Tensor,
   (...)
    494 ) -> Tuple[torch.Tensor]:
    495     # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
    496     self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
--> 497     self_attention_outputs = self.attention(
    498         hidden_states,
    499         attention_mask,
    500         head_mask,
    501         output_attentions=output_attentions,
    502         past_key_value=self_attn_past_key_value,
    503     )
    504     attention_output = self_attention_outputs[0]
    506     # if decoder, the last output is tuple of self-attn cache

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:436, in BertAttention.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
    417 def forward(
    418     self,
    419     hidden_states: torch.Tensor,
   (...)
    425     output_attentions: Optional[bool] = False,
    426 ) -> Tuple[torch.Tensor]:
    427     self_outputs = self.self(
    428         hidden_states,
    429         attention_mask,
   (...)
    434         output_attentions,
    435     )
--> 436     attention_output = self.output(self_outputs[0], hidden_states)
    437     outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
    438     return outputs

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:386, in BertSelfOutput.forward(self, hidden_states, input_tensor)
    385 def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
--> 386     hidden_states = self.dense(hidden_states)
    387     hidden_states = self.dropout(hidden_states)
    388     hidden_states = self.LayerNorm(hidden_states + input_tensor)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 768 n 3072 k 768 mat1_ld 768 mat2_ld 768 result_ld 768 abcType 0 computeType 68 scaleType 0

And here is another traceback (same code):

RuntimeError                              Traceback (most recent call last)
Cell In[16], line 148
    139 trainer = Trainer(
    140     model=model,
    141     args=args,
   (...)
    144     compute_metrics=compute_f1
    145 )
    147 # Training is really simple using our Trainer!
--> 148 trainer.train()
    150 # ... and so is evaluating!
    151 metrics = trainer.evaluate()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1537, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1535         hf_hub_utils.enable_progress_bars()
   1536 else:
-> 1537     return inner_training_loop(
   1538         args=args,
   1539         resume_from_checkpoint=resume_from_checkpoint,
   1540         trial=trial,
   1541         ignore_keys_for_eval=ignore_keys_for_eval,
   1542     )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1854, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1851     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1853 with self.accelerator.accumulate(model):
-> 1854     tr_loss_step = self.training_step(model, inputs)
   1856 if (
   1857     args.logging_nan_inf_filter
   1858     and not is_torch_tpu_available()
   1859     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1860 ):
   1861     # if loss is nan or inf simply add the average of previous logged losses
   1862     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2723, in Trainer.training_step(self, model, inputs)
   2720     return loss_mb.reduce_mean().detach().to(self.args.device)
   2722 with self.compute_loss_context_manager():
-> 2723     loss = self.compute_loss(model, inputs)
   2725 if self.args.n_gpu > 1:
   2726     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2746, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2744 else:
   2745     labels = None
-> 2746 outputs = model(**inputs)
   2747 # Save past state if it exists
   2748 # TODO: this needs to be fixed and made cleaner later.
   2749 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/span_marker/modeling.py:153, in SpanMarkerModel.forward(self, input_ids, attention_mask, position_ids, start_marker_indices, num_marker_pairs, labels, num_words, document_ids, sentence_ids, **kwargs)
    136 """Forward call of the SpanMarkerModel.
    137 
    138 Args:
   (...)
    150     SpanMarkerOutput: The output dataclass.
    151 """
    152 token_type_ids = torch.zeros_like(input_ids)
--> 153 outputs = self.encoder(
    154     input_ids,
    155     attention_mask=attention_mask,
    156     token_type_ids=token_type_ids,
    157     position_ids=position_ids,
    158 )
    159 last_hidden_state = outputs[0]
    160 last_hidden_state = self.dropout(last_hidden_state)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:1013, in BertModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1004 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
   1006 embedding_output = self.embeddings(
   1007     input_ids=input_ids,
   1008     position_ids=position_ids,
   (...)
   1011     past_key_values_length=past_key_values_length,
   1012 )
-> 1013 encoder_outputs = self.encoder(
   1014     embedding_output,
   1015     attention_mask=extended_attention_mask,
   1016     head_mask=head_mask,
   1017     encoder_hidden_states=encoder_hidden_states,
   1018     encoder_attention_mask=encoder_extended_attention_mask,
   1019     past_key_values=past_key_values,
   1020     use_cache=use_cache,
   1021     output_attentions=output_attentions,
   1022     output_hidden_states=output_hidden_states,
   1023     return_dict=return_dict,
   1024 )
   1025 sequence_output = encoder_outputs[0]
   1026 pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:607, in BertEncoder.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    596     layer_outputs = self._gradient_checkpointing_func(
    597         layer_module.__call__,
    598         hidden_states,
   (...)
    604         output_attentions,
    605     )
    606 else:
--> 607     layer_outputs = layer_module(
    608         hidden_states,
    609         attention_mask,
    610         layer_head_mask,
    611         encoder_hidden_states,
    612         encoder_attention_mask,
    613         past_key_value,
    614         output_attentions,
    615     )
    617 hidden_states = layer_outputs[0]
    618 if use_cache:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:539, in BertLayer.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
    536     cross_attn_present_key_value = cross_attention_outputs[-1]
    537     present_key_value = present_key_value + cross_attn_present_key_value
--> 539 layer_output = apply_chunking_to_forward(
    540     self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
    541 )
    542 outputs = (layer_output,) + outputs
    544 # if decoder, return the attn key/values as the last output

File /opt/conda/lib/python3.10/site-packages/transformers/pytorch_utils.py:242, in apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
    239     # concatenate output at same dimension
    240     return torch.cat(output_chunks, dim=chunk_dim)
--> 242 return forward_fn(*input_tensors)

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:552, in BertLayer.feed_forward_chunk(self, attention_output)
    550 def feed_forward_chunk(self, attention_output):
    551     intermediate_output = self.intermediate(attention_output)
--> 552     layer_output = self.output(intermediate_output, attention_output)
    553     return layer_output

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:466, in BertOutput.forward(self, hidden_states, input_tensor)
    464 hidden_states = self.dense(hidden_states)
    465 hidden_states = self.dropout(hidden_states)
--> 466 hidden_states = self.LayerNorm(hidden_states + input_tensor)
    467 return hidden_states

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/normalization.py:190, in LayerNorm.forward(self, input)
    189 def forward(self, input: Tensor) -> Tensor:
--> 190     return F.layer_norm(
    191         input, self.normalized_shape, self.weight, self.bias, self.eps)

File /opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:2515, in layer_norm(input, normalized_shape, weight, bias, eps)
   2511 if has_torch_function_variadic(input, weight, bias):
   2512     return handle_torch_function(
   2513         layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
   2514     )
-> 2515 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions

Sometimes I also get this one:

/usr/local/src/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [313,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

I know this probably isn't much to work on. Let me know if you have any advice for me.


transformers==4.36.0
span-marker==1.5.0
torch==2.0.0
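
Following the hint in the traceback, I can try to get a synchronous, readable error instead of the asynchronous CUDA assert with something like this (sketch; `trainer` is the same Trainer as above):

import os

# Must be set before CUDA is initialised, so ideally at the very top of the notebook.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Alternatively, run a single collated batch on CPU, where indexing errors
# surface as a clear Python exception instead of a device-side assert.
model_cpu = model.to("cpu")
batch = next(iter(trainer.get_train_dataloader()))
batch = {k: (v.to("cpu") if hasattr(v, "to") else v) for k, v in batch.items()}
outputs = model_cpu(**batch)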

@tomaarsen
Owner

tomaarsen commented Jan 9, 2024

Hello!

Thanks for listing the relevant versions. It's strange that e.g. XLM-RoBERTa does work, but BERT doesn't. That's understandably annoying, as the BERT models are still quite good. I'm not able to reproduce this, I'm afraid.

I have two suggestions:

  1. Could you try running inference with a model that was already trained with bert-base-cased, e.g. https://huggingface.co/tomaarsen/span-marker-bert-base-fewnerd-fine-super? (A quick sketch follows below this list.)
  2. Could you try bumping your torch version? X.0.0 is always a potentially buggy release, often followed by patch releases (e.g. 2.0.1).
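
For 1., something like this should be enough to check whether plain BERT inference works in your environment (quick sketch):

from span_marker import SpanMarkerModel

# A SpanMarker model that was already trained on top of bert-base-cased.
model = SpanMarkerModel.from_pretrained("tomaarsen/span-marker-bert-base-fewnerd-fine-super")

entities = model.predict(
    "Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris."
)
print(entities)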

Edit: I've not been able to reproduce this even with torch==2.0.0

  • Tom Aarsen

@moloti

moloti commented Jan 22, 2024

Hi there,
I have the same issue.
The error occurs here:
last_hidden_state[i, start_marker_indices[i] : end_marker_indices[i]]

And then I get an Assertion `srcIndex < srcSelectDimSize` failed, which seems to be an out-of-bounds indexing issue during tensor slicing in the CUDA kernel.

I wonder if this error is connected to the document-level context.
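
To confirm that it really is out-of-bounds indexing, a quick sanity check on one collated batch could look like this (sketch; it assumes a Trainer built as in the getting started example, and that the underlying BERT is reachable as model.encoder, as in the tracebacks above):

# Grab one preprocessed batch exactly the way the Trainer would.
batch = next(iter(trainer.get_train_dataloader()))
input_ids = batch["input_ids"]

embeddings = model.encoder.get_input_embeddings()
print("max input id:", input_ids.max().item(), "| embedding rows:", embeddings.num_embeddings)
print("sequence length:", input_ids.shape[-1],
      "| max position embeddings:", model.encoder.config.max_position_embeddings)

# Any input id >= embeddings.num_embeddings, or a sequence longer than
# max_position_embeddings, would explain the `srcIndex < srcSelectDimSize` assert.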

Model Config:
model = SpanMarkerModel.from_pretrained(
    model_name,
    labels=labels,
    # SpanMarker hyperparameters:
    model_max_length=512,
    marker_max_length=128,
    entity_max_length=8,
    max_prev_context=2,
    max_next_context=2,
    # Model card arguments
    model_card_data=SpanMarkerModelCardData(
        model_id=model_id,
        encoder_id=model_name,
        dataset_name=dataset_name,
        dataset_id=dataset_name,
        license="other",
        language="en",
    ),
)

# Prepare the 🤗 transformers training arguments
args = TrainingArguments(
    output_dir="models/span_marker_biolinkbert_base_1000",
    # Training Hyperparameters:
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    warmup_ratio=0.1,
    bf16=True,  # Replace `bf16` with `fp16` if your hardware can't use bf16.
    # Other Training parameters
    logging_first_step=True,
    logging_steps=50,
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=100,
    save_total_limit=2,
    dataloader_num_workers=2,
)

@moloti

moloti commented Jan 22, 2024

Update:
It works if the max context lengths are set:
max_prev_context=2,
max_next_context=2

@Definelymes

I still have the issue described above, and setting

max_prev_context=2,
max_next_context=2

hasn't solved it for me.
I was able to fine-tune xlm-roberta, but bert keeps failing.
