Highlights
TorchMetrics v0.9 is now out, and it brings significant changes to how the forward method works. This post walks through these improvements and how they affect both users of TorchMetrics and those who implement custom metrics. TorchMetrics v0.9 also includes several new metrics and bug fixes.
Blog: TorchMetrics v0.9 — Faster forward
The Story of the Forward Method
Since the beginning of TorchMetrics, forward has served the dual purpose of calculating the metric on the current batch and accumulating it in a global state. Internally, this was achieved by calling update twice: once for each purpose, which meant repeating the same computation. However, for many metrics this double call is unnecessary, because the global statistics are simple reductions of the per-batch states.
In v0.9, we have finally implemented logic that takes advantage of this and calls update only once, followed by a simple reduction. As the figure below shows, a single call to forward can be up to 2x faster in v0.9 than in v0.8 for the same metric.
Figure: With the improvements to forward, many metrics have become significantly faster (up to 2x).
It should be noted that this change mainly benefits metrics, such as `ConfusionMatrix`, where calling update is expensive.

We went through all existing metrics in TorchMetrics and enabled this feature wherever appropriate, which covered almost 95% of all metrics. We want to stress that if you are only using metrics from TorchMetrics, the API has not changed and no code changes are necessary.
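To make the change concrete, here is a minimal sketch in plain Python of the old double-update forward versus the new single-update-plus-reduction forward. This is not the actual TorchMetrics implementation; the class and method names (`SumAccuracy`, `forward_v08`, `forward_v09`) are illustrative, chosen for a metric whose global state is a simple sum of per-batch states.

```python
class SumAccuracy:
    """Accuracy with sum-reducible state: a correct count and a total count."""

    # In TorchMetrics v0.9 the analogous class attribute is `full_state_update`;
    # False signals that global state is a simple reduction of per-batch state.
    full_state_update = False

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # Accumulate batch statistics into the current state.
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.correct / self.total

    def forward_v08(self, preds, target):
        # Old behaviour: update is called twice per forward.
        self.update(preds, target)            # 1st call: grow the global state
        saved = (self.correct, self.total)
        self.correct = self.total = 0
        self.update(preds, target)            # 2nd, redundant call: batch-only state
        batch_value = self.compute()
        self.correct, self.total = saved      # restore accumulated global state
        return batch_value

    def forward_v09(self, preds, target):
        # New behaviour: stash the global state, call update ONCE on empty
        # state, compute the batch value, then reduce batch state into global.
        saved_correct, saved_total = self.correct, self.total
        self.correct = self.total = 0
        self.update(preds, target)            # single update call
        batch_value = self.compute()
        self.correct += saved_correct         # simple reduction (here: a sum)
        self.total += saved_total
        return batch_value


metric = SumAccuracy()
print(metric.forward_v09([1, 0, 1], [1, 1, 1]))  # batch accuracy: 2/3
print(metric.forward_v09([1, 1], [1, 1]))        # batch accuracy: 1.0
print(metric.compute())                          # accumulated: 4/5 = 0.8
```

Both forwards return the same batch value and leave the same global state behind; the v0.9 path simply skips the second update call, which is where the speedup for expensive updates comes from.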
[0.9.0] - 2022-05-31
Added

- `RetrievalPrecisionRecallCurve` and `RetrievalRecallAtFixedPrecision` to the retrieval package (#951)
- `full_state_update` class attribute that determines whether `forward` should call `update` once or twice (#984, #1033)
- `Dice` to the classification package (#1021)
- Support for `segm` as IoU for mean average precision (#822)

Changed

- Renamed `reduction` argument to `average` in Jaccard score and added additional options (#874)

Removed

- Deprecated `compute_on_step` argument (#962, #967, #979, #990, #991, #993, #1005, #1004, #1007)

Fixed

- Non-empty state `dict` for a few metrics (#1012)
- `torch.double` support in stat score metrics (#1023)
- `FID` calculation for real and fake inputs of unequal size (#1028)
- Case where `KLDivergence` could output `NaN` (#1030)
- Default value for `mdmc_average` in `Accuracy`, as per the documentation (#1036)
- Missing update of property in `MetricCollection` (#1052)

Contributors
@Borda, @burglarhobbit, @charlielito, @gianscarpe, @MrShevan, @phaseolud, @razmikmelikbekyan, @SkafteNicki, @tanmoyio, @vumichien
If we forgot someone due to not matching commit email with GitHub account, let us know :]
This discussion was created from the release Faster forward.