Skip to content

Commit

Permalink
fix: dice loss update
Browse files Browse the repository at this point in the history
  • Loading branch information
sithu31296 committed Sep 24, 2021
1 parent 9798e02 commit efe0e4d
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,21 @@ def __init__(self, delta: float = 0.5, aux_weights: list = [1, 0.4]):
self.delta = delta
self.aux_weights = aux_weights

def _forward(self, preds: Tensor, targets: Tensor) -> Tensor:
# preds in shape [B, C, H, W] and targets in shape [B, C, H, W]
if preds.shape[-2:] != targets.shape[-2:]:
preds = F.interpolate(preds, size=targets.shape[2:], mode='bilinear', align_corners=False)
def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
# preds in shape [B, C, H, W] and labels in shape [B, H, W]
if preds.shape[-2:] != labels.shape[-2:]:
preds = F.interpolate(preds, size=labels.shape[1:], mode='bilinear', align_corners=False)

tp = torch.sum(targets*preds, dim=(2, 3))
fn = torch.sum(targets*(1-preds), dim=(2, 3))
fp = torch.sum((1-targets)*preds, dim=(2, 3))
num_classes = preds.shape[1]
labels = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)
tp = torch.sum(labels*preds, dim=(2, 3))
fn = torch.sum(labels*(1-preds), dim=(2, 3))
fp = torch.sum((1-labels)*preds, dim=(2, 3))

dice_score = (tp + 1e-6) / (tp + self.delta * fn + (1 - self.delta) * fp + 1e-6)
dice_score = torch.sum(1-dice_score, dim=-1)
dice_score = torch.sum(1 - dice_score, dim=-1)

# adjust loss to account for number of classes
dice_score = dice_score / targets.shape[1]
dice_score = dice_score / num_classes
return dice_score.mean()

def forward(self, preds, targets: Tensor) -> Tensor:
Expand All @@ -94,8 +95,8 @@ def get_loss(loss_fn_name: str = 'ce', ignore_label: int = 255, cls_weights: Ten


if __name__ == '__main__':
pred = [torch.randint(0, 19, (2, 19, 224, 224), dtype=torch.float) for _ in range(2)]
label = torch.randint(0, 19, (2, 224, 224), dtype=torch.long)
loss_fn = OhemCrossEntropy(thresh=0.7)
pred = torch.randint(0, 19, (2, 19, 480, 640), dtype=torch.float)
label = torch.randint(0, 19, (2, 480, 640), dtype=torch.long)
loss_fn = Dice()
y = loss_fn(pred, label)
print(y)

0 comments on commit efe0e4d

Please sign in to comment.