Skip to content

Commit

Permalink
IMPORTANT CHANGE: Normalize rep hinge loss to number of hits of inter…
Browse files Browse the repository at this point in the history
…est (#475)

* IMPORTANT CHANGE: Normalize rep hinge loss to n hits oi

* Bump version
  • Loading branch information
klieret authored Dec 21, 2023
1 parent 960f83f commit 18a9ce5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
This changelog mostly collects important changes to the models that are not
backward compatible but result in different results.

## 23.12.1

### Breaking changes

* #475 (changed normalization of repulsive hinge loss for GC)

## 23.12.0

### Breaking changes
Expand Down
20 changes: 18 additions & 2 deletions src/gnn_tracking/metrics/losses/metric_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _hinge_loss_components(
r_emb_hinge: float,
p_attr: float,
p_rep: float,
n_hits_oi: int,
) -> tuple[T, T]:
eps = 1e-9

Expand All @@ -27,8 +28,14 @@ def _hinge_loss_components(
v_att = torch.sum(torch.pow(dists_att, p_attr)) / norm_att

dists_rep = norm(x[rep_edges[0]] - x[rep_edges[1]], dim=-1)
norm_rep = rep_edges.shape[1] + eps
v_rep = r_emb_hinge - torch.sum(torch.pow(dists_rep, p_rep)) / norm_rep
# There is no "good" way to normalize this: The naive way would be
# to normalize to the number of repulsive edges, but this number
# gets smaller and smaller as the training progresses, making the objective
# increasingly harder.
# The maximal number of edges that can be in the radius graph is proportional
# to the number of hits of interest, so we normalize by this number.
norm_rep = n_hits_oi + eps
v_rep = torch.sum(r_emb_hinge - torch.pow(dists_rep, p_rep)) / norm_rep

return v_att, v_rep

Expand Down Expand Up @@ -100,6 +107,8 @@ def forward(
pt_thld=self.hparams.pt_thld,
max_eta=self.hparams.max_eta,
)
# oi = of interest
n_hits_oi = mask.sum()
att_edges, rep_edges = self._get_edges(
x=x,
batch=batch,
Expand All @@ -114,6 +123,7 @@ def forward(
r_emb_hinge=self.hparams.r_emb,
p_attr=self.hparams.p_attr,
p_rep=self.hparams.p_rep,
n_hits_oi=n_hits_oi,
)
losses = {
"attractive": attr,
Expand All @@ -123,7 +133,13 @@ def forward(
"attractive": 1.0,
"repulsive": self.hparams.lw_repulsive,
}
extra = {
"n_hits_oi": n_hits_oi,
"n_edges_att": att_edges.shape[1],
"n_edges_rep": rep_edges.shape[1],
}
return MultiLossFctReturn(
loss_dct=losses,
weight_dct=weights,
extra_metrics=extra,
)
2 changes: 1 addition & 1 deletion src/gnn_tracking/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
23.12.0
23.12.1
3 changes: 2 additions & 1 deletion tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

T: TypeAlias = torch.Tensor

# Ignore print statements
# ruff: noqa: T201


Expand Down Expand Up @@ -192,7 +193,7 @@ def get_ml_loss(loss_fct: Callable, td: MockData) -> dict[str, float]:

def test_hinge_loss():
assert get_ml_loss(GraphConstructionHingeEmbeddingLoss(), td1) == approx(
{"attractive": 0.7307405975481213, "repulsive": 0.34612957938781874}
{"attractive": 0.7307405975481213, "repulsive": 11.076146539572338}
)


Expand Down

0 comments on commit 18a9ce5

Please sign in to comment.