From bcfc424817d59f29d65fb176825c09e352a6a0b7 Mon Sep 17 00:00:00 2001 From: fanqiNO1 <1848839264@qq.com> Date: Wed, 6 Mar 2024 14:35:40 +0800 Subject: [PATCH] [Enhance] Extract update loss --- mmengine/runner/loops.py | 61 ++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 4f60c3328f..bc80bd3b7c 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -401,21 +401,7 @@ def run_iter(self, idx, data_batch: Sequence[dict]): with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) - if isinstance(outputs[-1], - BaseDataElement) and outputs[-1].keys() == ['loss']: - loss = outputs[-1].loss # type: ignore - outputs = outputs[:-1] - else: - loss = dict() - # get val loss and avoid breaking change - for loss_name, loss_value in loss.items(): - if loss_name not in self.val_loss: - self.val_loss[loss_name] = HistoryBuffer() - if isinstance(loss_value, torch.Tensor): - self.val_loss[loss_name].update(loss_value.item()) - elif is_list_of(loss_value, torch.Tensor): - for loss_value_i in loss_value: - self.val_loss[loss_name].update(loss_value_i.item()) + outputs, self.val_loss = _update_losses(outputs, self.val_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( @@ -498,21 +484,7 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: with autocast(enabled=self.fp16): outputs = self.runner.model.test_step(data_batch) - if isinstance(outputs[-1], - BaseDataElement) and outputs[-1].keys() == ['loss']: - loss = outputs[-1].loss # type: ignore - outputs = outputs[:-1] - else: - loss = dict() - # get val loss and avoid breaking change - for loss_name, loss_value in loss.items(): - if loss_name not in self.test_loss: - self.test_loss[loss_name] = HistoryBuffer() - if isinstance(loss_value, torch.Tensor): - self.test_loss[loss_name].update(loss_value.item()) - elif is_list_of(loss_value, torch.Tensor): - for loss_value_i in loss_value: - self.test_loss[loss_name].update(loss_value_i.item()) + outputs, self.test_loss = _update_losses(outputs, self.test_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( @@ -545,3 +517,32 @@ def _parse_losses(losses: Dict[str, HistoryBuffer], loss_dict[f'{stage}_loss'] = all_loss return loss_dict + + +def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]: + """Update and record the losses of the network. + + Args: + outputs (list): The outputs of the network. + losses (dict): The losses of the network. + + Returns: + list: The updated outputs of the network. + dict: The updated losses of the network. + """ + if isinstance(outputs[-1], + BaseDataElement) and outputs[-1].keys() == ['loss']: + loss = outputs[-1].loss # type: ignore + outputs = outputs[:-1] + else: + loss = dict() + + for loss_name, loss_value in loss.items(): + if loss_name not in losses: + losses[loss_name] = HistoryBuffer() + if isinstance(loss_value, torch.Tensor): + losses[loss_name].update(loss_value.item()) + elif is_list_of(loss_value, torch.Tensor): + for loss_value_i in loss_value: + losses[loss_name].update(loss_value_i.item()) + return outputs, losses