diff --git a/CHANGELOG.md b/CHANGELOG.md index 0488f261..8e139ae1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,12 @@ This changelog mostly collects important changes to the models that are not backward compatible but result in different results. +## 23.12.1 + +### Breaking changes + +* #475 (changed normalization of repulsive hinge loss for GC) + ## 23.12.0 ### Breaking changes diff --git a/src/gnn_tracking/metrics/losses/metric_learning.py b/src/gnn_tracking/metrics/losses/metric_learning.py index 649fe387..a5172fee 100644 --- a/src/gnn_tracking/metrics/losses/metric_learning.py +++ b/src/gnn_tracking/metrics/losses/metric_learning.py @@ -19,6 +19,7 @@ def _hinge_loss_components( r_emb_hinge: float, p_attr: float, p_rep: float, + n_hits_oi: int, ) -> tuple[T, T]: eps = 1e-9 @@ -27,8 +28,14 @@ def _hinge_loss_components( v_att = torch.sum(torch.pow(dists_att, p_attr)) / norm_att dists_rep = norm(x[rep_edges[0]] - x[rep_edges[1]], dim=-1) - norm_rep = rep_edges.shape[1] + eps - v_rep = r_emb_hinge - torch.sum(torch.pow(dists_rep, p_rep)) / norm_rep + # There is no "good" way to normalize this: The naive way would be + # to normalize by the number of repulsive edges, but this number + # gets smaller and smaller as the training progresses, making the objective + # increasingly hard. + # The maximal number of edges that can be in the radius graph is proportional + # to the number of hits of interest, so we normalize by this number. 
+ norm_rep = n_hits_oi + eps + v_rep = torch.sum(r_emb_hinge - torch.pow(dists_rep, p_rep)) / norm_rep return v_att, v_rep @@ -100,6 +107,8 @@ def forward( pt_thld=self.hparams.pt_thld, max_eta=self.hparams.max_eta, ) + # oi = of interest + n_hits_oi = mask.sum() att_edges, rep_edges = self._get_edges( x=x, batch=batch, @@ -114,6 +123,7 @@ def forward( r_emb_hinge=self.hparams.r_emb, p_attr=self.hparams.p_attr, p_rep=self.hparams.p_rep, + n_hits_oi=n_hits_oi, ) losses = { "attractive": attr, @@ -123,7 +133,13 @@ def forward( "attractive": 1.0, "repulsive": self.hparams.lw_repulsive, } + extra = { + "n_hits_oi": n_hits_oi, + "n_edges_att": att_edges.shape[1], + "n_edges_rep": rep_edges.shape[1], + } return MultiLossFctReturn( loss_dct=losses, weight_dct=weights, + extra_metrics=extra, ) diff --git a/src/gnn_tracking/version.txt b/src/gnn_tracking/version.txt index 34291177..e42e273b 100644 --- a/src/gnn_tracking/version.txt +++ b/src/gnn_tracking/version.txt @@ -1 +1 @@ -23.12.0 +23.12.1 diff --git a/tests/test_losses.py b/tests/test_losses.py index 56b649fd..e84ff823 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -25,6 +25,7 @@ T: TypeAlias = torch.Tensor +# Ignore print statements # ruff: noqa: T201 @@ -192,7 +193,7 @@ def get_ml_loss(loss_fct: Callable, td: MockData) -> dict[str, float]: def test_hinge_loss(): assert get_ml_loss(GraphConstructionHingeEmbeddingLoss(), td1) == approx( - {"attractive": 0.7307405975481213, "repulsive": 0.34612957938781874} + {"attractive": 0.7307405975481213, "repulsive": 11.076146539572338} )