Skip to content

Commit

Permalink
[Enhance] Extract update loss
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed Mar 6, 2024
1 parent 4d76d6c commit bcfc424
Showing 1 changed file with 31 additions and 30 deletions.
61 changes: 31 additions & 30 deletions mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,21 +401,7 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)

if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()
# get val loss and avoid breaking change
for loss_name, loss_value in loss.items():
if loss_name not in self.val_loss:
self.val_loss[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
self.val_loss[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
for loss_value_i in loss_value:
self.val_loss[loss_name].update(loss_value_i.item())
outputs, self.val_loss = _update_losses(outputs, self.val_loss)

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
Expand Down Expand Up @@ -498,21 +484,7 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch)

if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()
# get val loss and avoid breaking change
for loss_name, loss_value in loss.items():
if loss_name not in self.test_loss:
self.test_loss[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
self.test_loss[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
for loss_value_i in loss_value:
self.test_loss[loss_name].update(loss_value_i.item())
outputs, self.test_loss = _update_losses(outputs, self.test_loss)

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
Expand Down Expand Up @@ -545,3 +517,32 @@ def _parse_losses(losses: Dict[str, HistoryBuffer],

loss_dict[f'{stage}_loss'] = all_loss
return loss_dict


def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
"""Update and record the losses of the network.
Args:
outputs (list): The outputs of the network.
losses (dict): The losses of the network.
Returns:
list: The updated outputs of the network.
dict: The updated losses of the network.
"""
if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()

for loss_name, loss_value in loss.items():
if loss_name not in losses:
losses[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
losses[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
for loss_value_i in loss_value:
losses[loss_name].update(loss_value_i.item())
return outputs, losses

0 comments on commit bcfc424

Please sign in to comment.