diff --git a/spago.yaml b/spago.yaml index 90f0bd1..057e339 100644 --- a/spago.yaml +++ b/spago.yaml @@ -15,6 +15,7 @@ package: - maybe - mmorph - newtype + - ordered-collections - parallel - postgresql-client - prelude diff --git a/src/Control.Monad.Postgres.purs b/src/Control.Monad.Postgres.purs index ab8da46..961ed3d 100644 --- a/src/Control.Monad.Postgres.purs +++ b/src/Control.Monad.Postgres.purs @@ -11,7 +11,7 @@ import Control.Monad.Reader (class MonadAsk, class MonadReader, ReaderT, ask) import Control.Monad.Rec.Class (class MonadRec) import Control.Monad.State (StateT(..)) import Control.Monad.Trans.Class (class MonadTrans, lift) -import Control.Monad.Writer (class MonadTell, class MonadWriter, WriterT(..)) +import Control.Monad.Writer (class MonadTell, class MonadWriter) import Control.MonadPlus (class MonadPlus) import Control.Parallel (class Parallel, parallel, sequential) import Data.Bifunctor (lmap) @@ -26,7 +26,6 @@ import Effect.Aff.Class (class MonadAff, liftAff) import Effect.Aff.Unlift (class MonadUnliftAff, UnliftAff(..), withUnliftAff) import Effect.Class (class MonadEffect) import Effect.Exception (Error, error) -import Foreign (Foreign) newtype HasPostgresT :: (Type -> Type) -> Type -> Type newtype HasPostgresT m a = HasPostgresT (ReaderT Pg.Connection m a) @@ -154,13 +153,13 @@ class (Monad m, MonadThrow Error m) <= MonadPostgres m where instance (MonadUnliftAff m, MonadThrow Error m) => MonadPostgres (PostgresT m) where query' nm q = do conn <- ask - qs /\ ps <- Query.runBuilder $ hoist nm q - res <- liftAff $ Pg.Aff.query conn (Pg.Query qs) ps + qs /\ { params } <- Query.runBuilder $ hoist nm q + res <- liftAff $ Pg.Aff.query conn (Pg.Query qs) params liftEither $ lmap (error <<< show) $ res exec' nm q = do conn <- ask - qs /\ ps <- Query.runBuilder $ hoist nm q - res <- liftAff $ Pg.Aff.execute conn (Pg.Query qs) ps + qs /\ { params } <- Query.runBuilder $ hoist nm q + res <- liftAff $ Pg.Aff.execute conn (Pg.Query qs) params liftEither $ lmap (error <<< show) $ maybe (Right unit) Left $ res transaction pg = do conn :: Pg.Connection <- ask diff --git a/src/Data.Postgres.Query.Builder.purs b/src/Data.Postgres.Query.Builder.purs index 8159ee3..8edd317 100644 --- a/src/Data.Postgres.Query.Builder.purs +++ b/src/Data.Postgres.Query.Builder.purs @@ -2,20 +2,24 @@ module Data.Postgres.Query.Builder where import Prelude -import Control.Monad.State (StateT, get, put, runStateT) +import Control.Monad.State (StateT, modify, runStateT) import Data.Array as Array +import Data.Set (Set) +import Data.Set as Set import Data.Tuple.Nested (type (/\)) import Database.PostgreSQL (class ToSQLValue, toSQLValue) import Foreign (Foreign) -type BuilderT m a = StateT (Array Foreign) m a +type BuilderT m a = StateT { params :: Array Foreign, refs :: Set String } m a -runBuilder :: forall m a. BuilderT m a -> m (a /\ Array Foreign) -runBuilder = flip runStateT [] +runBuilder :: forall m a. BuilderT m a -> m (a /\ { params :: Array Foreign, refs :: Set String }) +runBuilder = flip runStateT { params: [], refs: Set.empty } + +reference :: forall m. Monad m => String -> BuilderT m Unit +reference k = void $ modify (\s@{ refs } -> s { refs = Set.insert k refs }) param :: forall m a. Monad m => ToSQLValue a => a -> BuilderT m String param p = do - ps <- get - put $ ps <> [ toSQLValue p ] - pure $ "$" <> show (Array.length ps + 1) + { params } <- modify (\s@{ params } -> s { params = params <> [ toSQLValue p ] }) + pure $ "$" <> show (Array.length params) diff --git a/test/Spec.Data.Postgres.Query.Builder.purs b/test/Spec.Data.Postgres.Query.Builder.purs index 3208e6b..25d49a9 100644 --- a/test/Spec.Data.Postgres.Query.Builder.purs +++ b/test/Spec.Data.Postgres.Query.Builder.purs @@ -3,7 +3,8 @@ module Spec.Data.Postgres.Query.Builder where import Prelude import Control.Monad.Trans.Class (lift) -import Data.Postgres.Query.Builder (param, runBuilder) +import Data.Postgres.Query.Builder (param, reference, runBuilder) +import Data.Set as Set import Data.Tuple.Nested ((/\)) import Foreign (unsafeFromForeign) import Foreign.Internal.Stringify (unsafeStringify) @@ -15,18 +16,34 @@ spec = describe "Data.Postgres.Query.Builder" do describe "runBuilder" do it "empty" do - _ /\ ps <- runBuilder (pure unit) - map unsafeStringify ps `shouldEqual` [] + _ /\ { params, refs } <- runBuilder (pure unit) + map unsafeStringify params `shouldEqual` [] + refs `shouldEqual` Set.empty + describe "reference" do + it "one" do + _ /\ { refs } <- runBuilder $ reference "foo" + refs `shouldEqual` (Set.singleton "foo") + it "dup" do + _ /\ { refs } <- runBuilder do + reference "foo" + reference "foo" + refs `shouldEqual` (Set.singleton "foo") + it "multiple" do + _ /\ { refs } <- runBuilder do + reference "foo" + reference "bar" + reference "baz" + refs `shouldEqual` (Set.fromFoldable [ "foo", "bar", "baz" ]) describe "param" do it "single" do - p /\ ps <- runBuilder $ param 123 + p /\ { params } <- runBuilder $ param 123 p `shouldEqual` "$1" - map unsafeFromForeign ps `shouldEqual` [ 123 ] + map unsafeFromForeign params `shouldEqual` [ 123 ] it "many" do - _ /\ ps <- runBuilder do + _ /\ { params } <- runBuilder do a <- param 123 b <- param "abc" c <- param [ 123 ] d <- param true lift $ [ a, b, c, d ] `shouldEqual` [ "$1", "$2", "$3", "$4" ] - map unsafeStringify ps `shouldEqual` [ "123", "\"abc\"", "[123]", "true" ] + map unsafeStringify params `shouldEqual` [ "123", "\"abc\"", "[123]", "true" ]