diff --git a/spago.yaml b/spago.yaml index 1d42891..54252d6 100644 --- a/spago.yaml +++ b/spago.yaml @@ -7,6 +7,7 @@ package: test: main: Test.Main dependencies: + - random - spec workspace: package_set: diff --git a/src/Data.Async.Mutex.purs b/src/Data.Async.Mutex.purs index 2817770..3f616fe 100644 --- a/src/Data.Async.Mutex.purs +++ b/src/Data.Async.Mutex.purs @@ -3,14 +3,12 @@ module Data.Async.Mutex (Mutex, MutexGuard) where import Prelude import Control.Monad.Error.Class (liftMaybe, throwError) -import Data.Async.Class (class AsyncState, class AsyncStateLock, class AsyncStateReadable, class AsyncStateWritable, read) +import Data.Async.Class (class AsyncState, class AsyncStateLock, class AsyncStateReadable, class AsyncStateWritable) import Data.Maybe (isNothing) -import Data.Traversable (for, for_, traverse) +import Data.Traversable (for, for_) import Effect.Aff.AVar (AVar) import Effect.Aff.AVar as AVar import Effect.Aff.Class (liftAff) -import Effect.Class (liftEffect) -import Effect.Console as Console import Effect.Exception (error) -- | A lock guaranteeing exclusive access to @@ -36,6 +34,7 @@ instance AsyncStateWritable Mutex MutexGuard where instance AsyncStateLock Mutex MutexGuard where unlock (Mutex stateCell) (MutexGuard localStateCell) = liftAff do state <- AVar.tryTake localStateCell + void $ AVar.tryTake localStateCell when (isNothing state) $ throwError $ error "MutexGuard unlocked already!" for_ state (flip AVar.put stateCell) lock (Mutex stateCell) = liftAff do diff --git a/src/Data.Async.RwLock.purs b/src/Data.Async.RwLock.purs index bad1aa8..0d1ebd2 100644 --- a/src/Data.Async.RwLock.purs +++ b/src/Data.Async.RwLock.purs @@ -80,7 +80,7 @@ instance AsyncStateReadable RwLock WriteGuard where read _ (WriteGuard stateCell) = liftAff $ liftMaybe (error "WriteGuard used after `unlock` invoked!") - =<< AVar.tryRead stateCell + =<< AVar.tryTake stateCell instance AsyncStateReadable RwLock ReadGuard where read (RwLock { readers: readersCell }) (ReadGuard id a) = liftAff do diff --git a/test/Test.Main.purs b/test/Test.Main.purs index 1de1440..9e6d266 100644 --- a/test/Test.Main.purs +++ b/test/Test.Main.purs @@ -2,18 +2,27 @@ module Test.Main where import Prelude -import Control.Monad.Error.Class (try) -import Control.Monad.State.Async (AsyncStateT, asyncModify, asyncPut, asyncRead, runAsyncState) +import Control.Monad.Error.Class (throwError, try) +import Control.Monad.Rec.Class (Step(..), forever, tailRecM, untilJust) +import Control.Monad.State.Async (AsyncStateT, asyncModify, asyncPut, asyncRead, asyncWrite, runAsyncState, runMutexState) import Control.Monad.Trans.Class (lift) -import Control.Parallel (parallel, sequential) +import Control.Parallel (parOneOf, parSequence_, parallel, sequential) +import Data.Array as Array import Data.Async.Class (class AsyncState) import Data.Async.Mutex (Mutex) import Data.Async.RwLock (RwLock) import Data.Either (isLeft) import Data.Identity (Identity) +import Data.Maybe (Maybe(..)) import Data.Newtype (wrap) +import Data.Traversable (for_) +import Data.Tuple.Nested ((/\)) import Effect (Effect) import Effect.Aff (Aff, delay, launchAff_) +import Effect.Aff.Class (liftAff) +import Effect.Class (liftEffect) +import Effect.Exception (error) +import Effect.Random (randomBool) import Test.Spec (SpecT, describe, it) import Test.Spec.Assertions (shouldEqual, shouldSatisfy) import Test.Spec.Reporter (consoleReporter) @@ -50,6 +59,58 @@ common = do pure @(AsyncStateT w String Aff) unit sequential (pure (\_ _ -> unit) <*> t1 <*> t2) asyncRead (shouldEqual "hello, john!") + it "supports concurrent state manipulation 2" do + let + t = do + _ <- asyncRead pure + asyncWrite (pure <<< (unit /\ _) <<< (_ <> "a")) + pure unit + + runAsyncState "" do + pure @(AsyncStateT w String Aff) unit + parSequence_ $ Array.replicate 10 t + asyncRead (shouldEqual "aaaaaaaaaa") + it "supports parallel delay in monadrec" do + let + done = liftAff $ delay $ wrap 5000.0 + go 0 = pure $ Done unit + go n = do + liftAff $ delay $ wrap 2.0 + _ <- asyncRead pure + _ <- asyncModify pure + pure $ Loop (n - 1) + + runAsyncState "" do + pure @(AsyncStateT w String Aff) unit + parOneOf + [ done + , tailRecM go 100 + , tailRecM go 100 + , tailRecM go 100 + , tailRecM go 100 + ] + it "setting the state to a value unblocks MonadRec" do + let + delayThenDone = do + liftAff $ delay $ wrap 100.0 + asyncPut true + wait = untilJust do + done <- asyncRead pure + pure $ if done then Just unit else Nothing + runMutexState false $ parSequence_ [ delayThenDone, wait ] + it "throwing with lock does not block other threads" do + let + t = do + readThrows <- liftEffect randomBool + writeThrows <- liftEffect randomBool + void $ try $ asyncRead (const $ if readThrows then throwError $ error "fail" else pure unit) + void $ try $ asyncModify (\s -> if writeThrows then throwError $ error "fail" else pure s) + pure unit + + for_ (Array.replicate 100 unit) \_ -> do + runAsyncState "" do + pure @(AsyncStateT w String Aff) unit + parSequence_ $ Array.replicate 100 t main :: Effect Unit main = launchAff_ $ runSpec [ consoleReporter ] do