diff --git a/chronos/apps/http/httpbodyrw.nim b/chronos/apps/http/httpbodyrw.nim index b948fbd3e..90169b4d2 100644 --- a/chronos/apps/http/httpbodyrw.nim +++ b/chronos/apps/http/httpbodyrw.nim @@ -45,8 +45,8 @@ proc closeWait*(bstream: HttpBodyReader) {.async.} = # data from stream at position [1]. for index in countdown((len(bstream.streams) - 1), 0): res.add(bstream.streams[index].closeWait()) - await allFutures(res) - await procCall(closeWait(AsyncStreamReader(bstream))) + res.add(procCall(closeWait(AsyncStreamReader(bstream)))) + await noCancelWait(allFutures(res)) bstream.bstate = HttpState.Closed untrackCounter(HttpBodyReaderTrackerName) diff --git a/chronos/apps/http/httpclient.nim b/chronos/apps/http/httpclient.nim index e764ab058..ff295a4cb 100644 --- a/chronos/apps/http/httpclient.nim +++ b/chronos/apps/http/httpclient.nim @@ -827,26 +827,30 @@ proc sessionWatcher(session: HttpSessionRef) {.async.} = break proc closeWait*(request: HttpClientRequestRef) {.async.} = + var pending: seq[FutureBase] if request.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}: request.state = HttpReqRespState.Closing if not(isNil(request.writer)): if not(request.writer.closed()): - await request.writer.closeWait() + pending.add(FutureBase(request.writer.closeWait())) request.writer = nil - await request.releaseConnection() + pending.add(FutureBase(request.releaseConnection())) + await noCancelWait(allFutures(pending)) request.session = nil request.error = nil request.state = HttpReqRespState.Closed untrackCounter(HttpClientRequestTrackerName) proc closeWait*(response: HttpClientResponseRef) {.async.} = + var pending: seq[FutureBase] if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}: response.state = HttpReqRespState.Closing if not(isNil(response.reader)): if not(response.reader.closed()): - await response.reader.closeWait() + pending.add(FutureBase(response.reader.closeWait())) response.reader = nil - await response.releaseConnection() + pending.add(FutureBase(response.releaseConnection())) + await noCancelWait(allFutures(pending)) response.session = nil response.error = nil response.state = HttpReqRespState.Closed diff --git a/chronos/asyncfutures2.nim b/chronos/asyncfutures2.nim index 1a5ec828b..be2971a25 100644 --- a/chronos/asyncfutures2.nim +++ b/chronos/asyncfutures2.nim @@ -850,6 +850,33 @@ proc checkedCancelAndWait*(fut: FutureBase): Future[bool] = res = fut.checkedCancel() retFuture +proc noCancelWait*[T](future: Future[T]): Future[T] = + let retFuture = newFuture[T]("chronos.noCancelWait(T)", + {FutureFlag.OwnCancelSchedule}) + template completeFuture() = + if future.completed(): + when T is void: + retFuture.complete() + else: + retFuture.complete(future.value) + elif future.failed(): + retFuture.fail(future.error) + else: + raiseAssert("Unexpected future state [" & $future.state & "]") + + proc continuation(udata: pointer) {.gcsafe.} = + completeFuture() + + proc cancellation(udata: pointer) {.gcsafe.} = + discard + + if future.finished(): + completeFuture() + else: + future.addCallback(continuation) + retFuture.cancelCallback = cancellation + retFuture + proc allFutures*(futs: varargs[FutureBase]): Future[void] = ## Returns a future which will complete only when all futures in ``futs`` ## will be completed, failed or canceled. @@ -886,7 +913,7 @@ proc allFutures*(futs: varargs[FutureBase]): Future[void] = if len(nfuts) == 0 or len(nfuts) == finishedFutures: retFuture.complete() - return retFuture + retFuture proc allFutures*[T](futs: varargs[Future[T]]): Future[void] = ## Returns a future which will complete only when all futures in ``futs`` diff --git a/tests/teststream.nim b/tests/teststream.nim index 13e662496..f76cb971d 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -1331,6 +1331,11 @@ suite "Stream Transport test suite": counter = 0 exitLoop = false + # This timer will help to awake events poll in case its going to stuck + # usually happens on MacOS. + + var sleepFut = sleepAsync(1.seconds) + while not(exitLoop): let server = createStreamServer(initTAddress("127.0.0.1:0")) @@ -1340,6 +1345,8 @@ suite "Stream Transport test suite": transpFut = connect(address) acceptFut = server.accept() + echo "AWAITING FOR [", counter, "] STEPS" + if counter > 0: await stepsAsync(counter) @@ -1379,6 +1386,9 @@ suite "Stream Transport test suite": await server.closeWait() echo "SERVER CLOSED" + if not(sleepFut.finished()): + await cancelAndWait(sleepFut) + echo "TEST EXITED" markFD = getCurrentFD()