Skip to content

Commit

Permalink
[core][combinators] workaround issue with nested boundaries (#850)
Browse files Browse the repository at this point in the history
This PR applies the workaround identified by @hearnadam with an `inline`
indirection. I think we should still keep #804 open to explore a more
proper solution. Users might need to define methods with boundaries and
the workaround can be confusing.
  • Loading branch information
fwbrasil authored Nov 19, 2024
1 parent 35cdf72 commit 0f3701d
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 107 deletions.
58 changes: 42 additions & 16 deletions kyo-combinators/shared/src/main/scala/kyo/Combinators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -637,14 +637,12 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
* @return
* A computation that produces the result of this computation with Async effect
*/
def fork(
inline def fork(
using
flat: Flat[A],
boundary: Boundary[Ctx, IO & Abort[E]],
reduce: Reducible[Abort[E]],
frame: Frame
): Fiber[E, A] < (IO & Ctx) =
Async.run(effect)(using flat, boundary)
Async.run(effect)

/** Forks this computation using the Async effect and returns its result as a `Fiber[E, A]`, managed by the Resource effect. Unlike
* `fork`, which creates an unmanaged fiber, `forkScoped` ensures that the fiber is properly cleaned up when the enclosing scope is
Expand All @@ -653,11 +651,9 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
* @return
* A computation that produces the result of this computation with Async and Resource effects
*/
def forkScoped(
inline def forkScoped(
using
flat: Flat[A],
boundary: Boundary[Ctx, IO & Abort[E]],
reduce: Reducible[Abort[E]],
frame: Frame
): Fiber[E, A] < (IO & Ctx & Resource) =
Kyo.acquireRelease(Async.run(effect))(_.interrupt.unit)
Expand Down Expand Up @@ -693,7 +689,17 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
* A computation that produces the result of `next`
*/
@targetName("zipRightPar")
def &>[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
inline def &>[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
using
f: Flat[A],
f1: Flat[A1],
r: Reducible[Abort[E]],
r1: Reducible[Abort[E1]],
fr: Frame
): A1 < (r.SReduced & r1.SReduced & Async & Ctx & Ctx1) =
_zipRightPar(next)

private def _zipRightPar[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
using
f: Flat[A],
f1: Flat[A1],
Expand All @@ -704,8 +710,8 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
fr: Frame
): A1 < (r.SReduced & r1.SReduced & Async & Ctx & Ctx1) =
for
fiberA <- effect.fork
fiberA1 <- next.fork
fiberA <- Async._run(effect)
fiberA1 <- Async._run(next)
_ <- fiberA.awaitCompletion
a1 <- fiberA1.join
yield a1
Expand All @@ -718,7 +724,17 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
* A computation that produces the result of this computation
*/
@targetName("zipLeftPar")
def <&[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
inline def <&[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
using
f: Flat[A],
f1: Flat[A1],
r: Reducible[Abort[E]],
r1: Reducible[Abort[E1]],
fr: Frame
): A < (r.SReduced & r1.SReduced & Async & Ctx & Ctx1) =
_zipLeftPar(next)

private def _zipLeftPar[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
using
f: Flat[A],
f1: Flat[A1],
Expand All @@ -729,8 +745,8 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
fr: Frame
): A < (r.SReduced & r1.SReduced & Async & Ctx & Ctx1) =
for
fiberA <- effect.fork
fiberA1 <- next.fork
fiberA <- Async._run(effect)
fiberA1 <- Async._run(next)
a <- fiberA.join
_ <- fiberA1.awaitCompletion
yield a
Expand All @@ -743,7 +759,17 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
* A computation that produces a tuple of both results
*/
@targetName("zipPar")
def <&>[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
inline def <&>[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
using
f: Flat[A],
f1: Flat[A1],
r: Reducible[Abort[E]],
r1: Reducible[Abort[E1]],
fr: Frame
): (A, A1) < (r.SReduced & r1.SReduced & Async & Ctx & Ctx1) =
_zipPar(next)

private def _zipPar[A1, E1, Ctx1](next: A1 < (Abort[E1] & Async & Ctx1))(
using
f: Flat[A],
f1: Flat[A1],
Expand All @@ -754,8 +780,8 @@ extension [A, E, Ctx](effect: A < (Abort[E] & Async & Ctx))
fr: Frame
): (A, A1) < (r.SReduced & r1.SReduced & Async & Ctx & Ctx1) =
for
fiberA <- effect.fork
fiberA1 <- next.fork
fiberA <- Async._run(effect)
fiberA1 <- Async._run(next)
a <- fiberA.join
a1 <- fiberA1.join
yield (a, a1)
Expand Down
7 changes: 2 additions & 5 deletions kyo-combinators/shared/src/main/scala/kyo/Constructors.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package kyo

import java.io.IOException
import kyo.kernel.Boundary
import kyo.kernel.Reducible
import scala.annotation.tailrec
import scala.concurrent.Future
Expand Down Expand Up @@ -109,10 +108,9 @@ extension (kyoObject: Kyo.type)
* @return
* A new sequence with elements collected using the function
*/
def foreachPar[E, A, S, A1, Ctx](sequence: Seq[A])(useElement: A => A1 < (Abort[E] & Async & Ctx))(
inline def foreachPar[E, A, S, A1, Ctx](sequence: Seq[A])(useElement: A => A1 < (Abort[E] & Async & Ctx))(
using
flat: Flat[A1],
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
): Seq[A1] < (Abort[E] & Async & Ctx) =
Async.parallelUnbounded[E, A1, Ctx](sequence.map(useElement))
Expand All @@ -126,10 +124,9 @@ extension (kyoObject: Kyo.type)
* @return
* Discards the results of the function application and returns Unit
*/
def foreachParDiscard[E, A, S, A1, Ctx](sequence: Seq[A])(useElement: A => A1 < (Abort[E] & Async & Ctx))(
inline def foreachParDiscard[E, A, S, A1, Ctx](sequence: Seq[A])(useElement: A => A1 < (Abort[E] & Async & Ctx))(
using
flat: Flat[A1],
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
): Unit < (Abort[E] & Async & Ctx) =
foreachPar(sequence)(useElement).unit
Expand Down
85 changes: 55 additions & 30 deletions kyo-core/shared/src/main/scala/kyo/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ object Async:
* A Fiber representing the running computation
*/
inline def run[E, A: Flat, Ctx](inline v: => A < (Abort[E] & Async & Ctx))(
using frame: Frame
): Fiber[E, A] < (IO & Ctx) =
_run(v)

private[kyo] inline def _run[E, A: Flat, Ctx](inline v: => A < (Abort[E] & Async & Ctx))(
using
boundary: Boundary[Ctx, IO & Abort[E]],
frame: Frame
Expand All @@ -54,12 +59,17 @@ object Async:
* @return
* The result of the computation, or a Timeout error
*/
def runAndBlock[E, A: Flat, Ctx](timeout: Duration)(v: => A < (Abort[E] & Async & Ctx))(
inline def runAndBlock[E, A: Flat, Ctx](timeout: Duration)(v: => A < (Abort[E] & Async & Ctx))(
using frame: Frame
): A < (Abort[E | Timeout] & IO & Ctx) =
_runAndBlock(timeout)(v)

private def _runAndBlock[E, A: Flat, Ctx](timeout: Duration)(v: => A < (Abort[E] & Async & Ctx))(
using
boundary: Boundary[Ctx, IO & Abort[E | Timeout]],
frame: Frame
): A < (Abort[E | Timeout] & IO & Ctx) =
run(v).map { fiber =>
_run(v).map { fiber =>
fiber.block(timeout).map(Abort.get(_))
}

Expand All @@ -74,12 +84,17 @@ object Async:
* @return
* The result of the computation, which can still be interrupted
*/
def mask[E, A: Flat, Ctx](v: => A < (Abort[E] & Async & Ctx))(
inline def mask[E, A: Flat, Ctx](v: => A < (Abort[E] & Async & Ctx))(
using frame: Frame
): A < (Abort[E] & Async & Ctx) =
_mask(v)

private def _mask[E, A: Flat, Ctx](v: => A < (Abort[E] & Async & Ctx))(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
): A < (Abort[E] & Async & Ctx) =
Async.run(v).map(_.mask.map(_.get))
_run(v).map(_.mask.map(_.get))

/** Delays execution of a computation by a specified duration.
*
Expand Down Expand Up @@ -111,7 +126,12 @@ object Async:
* @return
* The result of the computation, or a Timeout error
*/
def timeout[E, A: Flat, Ctx](after: Duration)(v: => A < (Abort[E] & Async & Ctx))(
inline def timeout[E, A: Flat, Ctx](after: Duration)(v: => A < (Abort[E] & Async & Ctx))(
using frame: Frame
): A < (Abort[E | Timeout] & Async & Ctx) =
_timeout(after)(v)

private def _timeout[E, A: Flat, Ctx](after: Duration)(v: => A < (Abort[E] & Async & Ctx))(
using
boundary: Boundary[Ctx, Async & Abort[E | Timeout]],
frame: Frame
Expand All @@ -129,7 +149,6 @@ object Async:
}
}
}
end timeout

/** Races multiple computations and returns the result of the first to complete. When one computation completes, all other computations
* are interrupted.
Expand All @@ -139,13 +158,18 @@ object Async:
* @return
* The result of the first computation to complete
*/
def race[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
inline def race[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
using frame: Frame
): A < (Abort[E] & Async & Ctx) =
_race(seq)

private def _race[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
): A < (Abort[E] & Async & Ctx) =
if seq.isEmpty then seq(0)
else Fiber.race(seq).map(_.get)
else Fiber._race(seq).map(_.get)

/** Races two or more computations and returns the result of the first to complete.
*
Expand All @@ -156,13 +180,11 @@ object Async:
* @return
* The result of the first computation to complete
*/
def race[E, A: Flat, Ctx](
inline def race[E, A: Flat, Ctx](
first: A < (Abort[E] & Async & Ctx),
rest: A < (Abort[E] & Async & Ctx)*
)(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
using frame: Frame
): A < (Abort[E] & Async & Ctx) =
race[E, A, Ctx](first +: rest)

Expand All @@ -179,17 +201,22 @@ object Async:
* @return
* A sequence containing the results of all computations in their original order
*/
def parallelUnbounded[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
inline def parallelUnbounded[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
using frame: Frame
): Seq[A] < (Abort[E] & Async & Ctx) =
_parallelUnbounded(seq)

private def _parallelUnbounded[E, A: Flat, Ctx](seq: Seq[A < (Abort[E] & Async & Ctx)])(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
): Seq[A] < (Abort[E] & Async & Ctx) =
seq.size match
case 0 => Seq.empty
case 1 => seq(0).map(Seq(_))
case _ => Fiber.parallelUnbounded(seq).map(_.get)
case _ => Fiber._parallelUnbounded(seq).map(_.get)
end match
end parallelUnbounded
end _parallelUnbounded

/** Runs multiple computations in parallel with a specified level of parallelism and returns their results.
*
Expand All @@ -209,16 +236,20 @@ object Async:
* @return
* A sequence containing the results of all computations in their original order
*/
def parallel[E, A: Flat, Ctx](parallelism: Int)(seq: Seq[A < (Abort[E] & Async & Ctx)])(
inline def parallel[E, A: Flat, Ctx](parallelism: Int)(seq: Seq[A < (Abort[E] & Async & Ctx)])(
using frame: Frame
): Seq[A] < (Abort[E] & Async & Ctx) =
_parallel(parallelism)(seq)

private def _parallel[E, A: Flat, Ctx](parallelism: Int)(seq: Seq[A < (Abort[E] & Async & Ctx)])(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
): Seq[A] < (Abort[E] & Async & Ctx) =
seq.size match
case 0 => Seq.empty
case 1 => seq(0).map(Seq(_))
case n => Fiber.parallel(parallelism)(seq).map(_.get)
end parallel
case n => Fiber._parallel(parallelism)(seq).map(_.get)

/** Runs two computations in parallel and returns their results as a tuple.
*
Expand All @@ -229,13 +260,11 @@ object Async:
* @return
* A tuple containing the results of both computations
*/
def parallel[E, A1: Flat, A2: Flat, Ctx](
inline def parallel[E, A1: Flat, A2: Flat, Ctx](
v1: A1 < (Abort[E] & Async & Ctx),
v2: A2 < (Abort[E] & Async & Ctx)
)(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
using frame: Frame
): (A1, A2) < (Abort[E] & Async & Ctx) =
parallelUnbounded(Seq(v1, v2))(using Flat.unsafe.bypass).map { s =>
(s(0).asInstanceOf[A1], s(1).asInstanceOf[A2])
Expand All @@ -252,14 +281,12 @@ object Async:
* @return
* A tuple containing the results of all three computations
*/
def parallel[E, A1: Flat, A2: Flat, A3: Flat, Ctx](
inline def parallel[E, A1: Flat, A2: Flat, A3: Flat, Ctx](
v1: A1 < (Abort[E] & Async & Ctx),
v2: A2 < (Abort[E] & Async & Ctx),
v3: A3 < (Abort[E] & Async & Ctx)
)(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
using frame: Frame
): (A1, A2, A3) < (Abort[E] & Async & Ctx) =
parallelUnbounded(Seq(v1, v2, v3))(using Flat.unsafe.bypass).map { s =>
(s(0).asInstanceOf[A1], s(1).asInstanceOf[A2], s(2).asInstanceOf[A3])
Expand All @@ -278,15 +305,13 @@ object Async:
* @return
* A tuple containing the results of all four computations
*/
def parallel[E, A1: Flat, A2: Flat, A3: Flat, A4: Flat, Ctx](
inline def parallel[E, A1: Flat, A2: Flat, A3: Flat, A4: Flat, Ctx](
v1: A1 < (Abort[E] & Async & Ctx),
v2: A2 < (Abort[E] & Async & Ctx),
v3: A3 < (Abort[E] & Async & Ctx),
v4: A4 < (Abort[E] & Async & Ctx)
)(
using
boundary: Boundary[Ctx, Async & Abort[E]],
frame: Frame
using frame: Frame
): (A1, A2, A3, A4) < (Abort[E] & Async & Ctx) =
parallelUnbounded(Seq(v1, v2, v3, v4))(using Flat.unsafe.bypass).map { s =>
(s(0).asInstanceOf[A1], s(1).asInstanceOf[A2], s(2).asInstanceOf[A3], s(3).asInstanceOf[A4])
Expand Down
Loading

0 comments on commit 0f3701d

Please sign in to comment.