[Feature] Support calculating loss during validation #1503

Merged · 16 commits · May 17, 2024
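The loop changes below assume a model whose `val_step` / `test_step` may append one trailing `BaseDataElement` whose only key is `loss`, holding a dict of named loss tensors; the loop strips that element off, accumulates the values per iteration, and reports averaged losses alongside the evaluator metrics. A minimal sketch of such a model (the class, field, and loss names here are hypothetical and not part of this PR):

```python
import torch.nn as nn
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement


class ToyModel(BaseModel):
    """Hypothetical model whose ``val_step`` also reports losses."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        preds = self.linear(inputs)
        if mode == 'loss':
            # Named losses; every name containing 'loss' is later summed
            # into the reported 'val_loss' / 'test_loss'.
            return {'loss_mse': nn.functional.mse_loss(preds, data_samples)}
        if mode == 'predict':
            return [BaseDataElement(pred=p) for p in preds]
        return preds

    def val_step(self, data):
        data = self.data_preprocessor(data, False)
        outputs = self._run_forward(data, mode='predict')
        losses = self._run_forward(data, mode='loss')
        # Trailing element with a single 'loss' key is what the new
        # ValLoop/TestLoop code looks for.
        return outputs + [BaseDataElement(loss=losses)]
```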
63 changes: 63 additions & 0 deletions mmengine/runner/loops.py
@@ -10,6 +10,8 @@
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@@ -361,17 +363,32 @@ def __init__(self,
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        self.val_loss: Dict[str, list] = dict()

    def run(self) -> dict:
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        # clear val loss
        self.val_loss.clear()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.val_loss:
            # get val loss and save to metrics
            val_loss = 0
            for loss_name, loss_value in self.val_loss.items():
                avg_loss = sum(loss_value) / len(loss_value)
                metrics[loss_name] = avg_loss
                if 'loss' in loss_name:
                    val_loss += avg_loss  # type: ignore
            metrics['val_loss'] = val_loss

        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics
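The aggregation above keeps one averaged entry per loss name and additionally sums every entry whose name contains 'loss' into a single `val_loss`. A small standalone illustration of that arithmetic (the loss names and numbers are made up, not from this PR):

```python
# Hypothetical accumulated values after three validation iterations.
val_loss_history = {'loss_cls': [0.9, 0.7, 0.5], 'loss_bbox': [0.4, 0.3, 0.2]}

metrics = {}
total = 0
for name, values in val_loss_history.items():
    avg = sum(values) / len(values)  # per-loss average over the epoch
    metrics[name] = avg
    if 'loss' in name:
        total += avg
metrics['val_loss'] = total
print(metrics)  # ≈ {'loss_cls': 0.7, 'loss_bbox': 0.3, 'val_loss': 1.0}
```

Because these values are merged into `metrics` before `after_val_epoch` is called, they behave like any evaluator metric; for example, a CheckpointHook configured with `save_best='val_loss', rule='less'` should be able to track the new key.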
@@ -389,6 +406,21 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
        # outputs should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.val_step(data_batch)
        if isinstance(outputs[-1],
                      BaseDataElement) and outputs[-1].keys() == ['loss']:
            loss = outputs[-1].loss  # type: ignore
            outputs = outputs[:-1]
        else:
            loss = dict()
        # get val loss and avoid breaking change
        for loss_name, loss_value in loss.items():
            if loss_name not in self.val_loss:
                self.val_loss[loss_name] = []
            if isinstance(loss_value, torch.Tensor):
                self.val_loss[loss_name].append(loss_value.item())
            elif is_list_of(loss_value, torch.Tensor):
                self.val_loss[loss_name].extend([v.item() for v in loss_value])

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_val_iter',
@@ -433,17 +465,32 @@ def __init__(self,
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        self.test_loss: Dict[str, list] = dict()

    def run(self) -> dict:
        """Launch test."""
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()

        # clear test loss
        self.test_loss.clear()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.test_loss:
            # get test loss and save to metrics
            test_loss = 0
            for loss_name, loss_value in self.test_loss.items():
                avg_loss = sum(loss_value) / len(loss_value)
                metrics[loss_name] = avg_loss
                if 'loss' in loss_name:
                    test_loss += avg_loss  # type: ignore
            metrics['test_loss'] = test_loss

        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics
@@ -460,6 +507,22 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        # predictions should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.test_step(data_batch)
        if isinstance(outputs[-1],
                      BaseDataElement) and outputs[-1].keys() == ['loss']:
            loss = outputs[-1].loss  # type: ignore
            outputs = outputs[:-1]
        else:
            loss = dict()
        # get test loss and avoid breaking change
        for loss_name, loss_value in loss.items():
            if loss_name not in self.test_loss:
                self.test_loss[loss_name] = []
            if isinstance(loss_value, torch.Tensor):
                self.test_loss[loss_name].append(loss_value.item())
            elif is_list_of(loss_value, torch.Tensor):
                self.test_loss[loss_name].extend(
                    [v.item() for v in loss_value])

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_test_iter',