diff --git a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py index 70d1096b..0d09ccc2 100644 --- a/pertpy/tools/_distances/_distances.py +++ b/pertpy/tools/_distances/_distances.py @@ -363,7 +363,7 @@ def solve_ot_problem(self, geom: Geometry, **kwargs): solver = Sinkhorn() # Solve OT problem ot = solver(ot_prob, **kwargs) - return ot.reg_ot_cost + return ot.reg_ot_cost.item() class PseudobulkDistance(AbstractDistance):