Skip to content

Commit

Permalink
Pipeline all the things
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisPenner committed Jul 31, 2024
1 parent bb807ae commit fc80be6
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 75 deletions.
66 changes: 50 additions & 16 deletions src/Share/Postgres.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ module Share.Postgres
Only (..),
QueryA (..),
QueryM (..),
unrecoverableError,
throwErr,
decodeField,
(:.) (..),

Expand Down Expand Up @@ -120,7 +122,7 @@ instance MonadError e (Transaction e) where
Right a -> pure (Right a)

-- | Applicative pipelining transaction
newtype Pipeline e a = Pipeline {unPipeline :: Hasql.Pipeline.Pipeline (Either (TransactionError e) a)}
newtype Pipeline e a = Pipeline {_unPipeline :: Hasql.Pipeline.Pipeline (Either (TransactionError e) a)}
deriving (Functor, Applicative) via (Compose Hasql.Pipeline.Pipeline (Either (TransactionError e)))

pFor :: (Traversable f) => f a -> (a -> Pipeline e b) -> Transaction e (f b)
Expand Down Expand Up @@ -320,9 +322,9 @@ class (Applicative m) => QueryA m e | m -> e where
statement :: q -> Hasql.Statement q r -> m r

-- | Fail the transaction and whole request with an unrecoverable server error.
unrecoverableError :: (HasCallStack, ToServerError x, Loggable x, Show x) => x -> m a
unrecoverableErrorA :: (HasCallStack, ToServerError x, Loggable x, Show x) => m (Either x a) -> m a

throwErr :: (ToServerError e, Loggable e, Show e) => e -> m a
throwErrA :: (ToServerError e, Loggable e, Show e) => m (Either e a) -> m a

pipelined :: Pipeline e a -> m a

Expand All @@ -335,11 +337,14 @@ instance QueryA (Transaction e) e where
statement q s = do
transactionStatement q s

throwErr = throwError
throwErrA m = m >>= either throwError pure

pipelined p = Transaction (Hasql.pipeline (unPipeline p))
pipelined (Pipeline p) = Transaction (Hasql.pipeline p)

unrecoverableError e = Transaction (pure (Left (Unrecoverable (someServerError e))))
unrecoverableErrorA me =
me >>= \case
Right a -> pure a
Left e -> Transaction (pure (Left (Unrecoverable (someServerError e))))

instance QueryM (Transaction e) e where
transactionUnsafeIO io = Transaction (Right <$> liftIO io)
Expand All @@ -348,49 +353,78 @@ instance QueryA (Session e) e where
statement q s = do
lift $ Session.statement q s

throwErr = throwError . Err
throwErrA m = m >>= either (throwError . Err) pure

pipelined p = do
ExceptT $ Hasql.pipeline (unPipeline p)
pipelined (Pipeline p) = do
ExceptT $ Hasql.pipeline p

unrecoverableError e = throwError (Unrecoverable (someServerError e))
unrecoverableErrorA me =
me >>= \case
Right a -> pure a
Left e -> throwError (Unrecoverable (someServerError e))

instance QueryM (Session e) e where
transactionUnsafeIO io = lift $ liftIO io

instance QueryA (Pipeline e) e where
statement q s = Pipeline (Right <$> Hasql.Pipeline.statement q s)

throwErr = Pipeline . pure . Left . Err
throwErrA (Pipeline me) =
-- Flatten error into pipeline
Pipeline $
me <&> \case
Left e -> Left e
Right (Left e) -> Left (Err e)
Right (Right a) -> Right a

pipelined p = p

unrecoverableError e = Pipeline $ pure (Left (Unrecoverable (someServerError e)))
unrecoverableErrorA (Pipeline me) =
Pipeline
( me <&> \case
Right (Left e) -> Left . Unrecoverable . someServerError $ e
Right (Right a) -> Right a
Left e -> Left e
)

-- Pipeline $ pure (Left (Unrecoverable (someServerError e)))

instance (QueryM m e) => QueryA (ReaderT r m) e where
statement q s = lift $ statement q s

throwErr = lift . throwErr
throwErrA m = mapReaderT throwErrA m

pipelined p = lift $ pipelined p

unrecoverableError e = lift $ unrecoverableError e
unrecoverableErrorA me = mapReaderT unrecoverableErrorA me

instance (QueryM m e) => QueryM (ReaderT r m) e where
transactionUnsafeIO io = lift $ transactionUnsafeIO io

instance (QueryM m e) => QueryA (MaybeT m) e where
statement q s = lift $ statement q s

throwErr = lift . throwErr
throwErrA m =
m >>= \case
Left e -> lift $ throwErr e
Right a -> pure a

pipelined p = lift $ pipelined p

unrecoverableError e = lift $ unrecoverableError e
unrecoverableErrorA m =
m >>= \case
Left e -> lift $ unrecoverableError e
Right a -> pure a

instance (QueryM m e) => QueryM (MaybeT m) e where
transactionUnsafeIO io = lift $ transactionUnsafeIO io

unrecoverableError :: (QueryA m e) => (ToServerError x, Loggable x, Show x) => x -> m a
unrecoverableError e = unrecoverableErrorA (pure $ Left e)

throwErr :: (QueryA m e, ToServerError e, Loggable e, Show e) => e -> m a
throwErr e = throwErrA (pure $ Left e)

prepareStatements :: Bool
prepareStatements = True

Expand Down
29 changes: 15 additions & 14 deletions src/Share/Postgres/Causal/Queries.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ import Unison.NameSegment.Internal as NameSegment
import Unison.Reference qualified as Reference
import Unison.Util.Map qualified as Map

expectCausalNamespace :: (HasCallStack, QueryM m e) => CausalId -> m (CausalNamespace m)
expectCausalNamespace :: (QueryM m e) => CausalId -> m (CausalNamespace m)
expectCausalNamespace causalId =
loadCausalNamespace causalId
`whenNothingM` unrecoverableError (MissingExpectedEntity $ "Expected causal branch for hash:" <> tShow causalId)
Expand Down Expand Up @@ -101,26 +101,27 @@ expectPgCausalNamespace causalId =

loadCausalNamespace :: forall m e. (QueryM m e) => CausalId -> m (Maybe (CausalNamespace m))
loadCausalNamespace causalId = runMaybeT $ do
causalHash <- HashQ.expectCausalHashesByIdsOf id causalId
branchHashId <- HashQ.expectNamespaceIdsByCausalIdsOf id causalId
namespaceHash <- HashQ.expectNamespaceHashesByNamespaceHashIdsOf id branchHashId
let namespace = expectNamespace branchHashId
ancestors <- lift $ ancestorsByCausalId causalId
pure $
Causal
{ causalHash = causalHash,
valueHash = namespaceHash,
parents = ancestors,
value = namespace
}
pipelined $ do
causalHash <- HashQ.expectCausalHashesByIdsOf id causalId
namespaceHash <- HashQ.expectNamespaceHashesByNamespaceHashIdsOf id branchHashId
let namespace = expectNamespace branchHashId
ancestors <- ancestorsByCausalId causalId
pure $
Causal
{ causalHash = causalHash,
valueHash = namespaceHash,
parents = ancestors,
value = namespace
}
where
ancestorsByCausalId :: CausalId -> m ((Map CausalHash (m (CausalNamespace m))))
ancestorsByCausalId :: CausalId -> Pipeline e ((Map CausalHash (m (CausalNamespace m))))
ancestorsByCausalId causalId = do
getAncestors
<&> fmap (\(ancestorId, ancestorHash) -> (ancestorHash, expectCausalNamespace ancestorId))
<&> Map.fromList
where
getAncestors :: m [(CausalId, CausalHash)]
getAncestors :: Pipeline e [(CausalId, CausalHash)]
getAncestors = do
queryListRows
[sql| SELECT ancestor_id, ancestor.hash
Expand Down
18 changes: 10 additions & 8 deletions src/Share/Postgres/Definitions/Queries.hs
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,11 @@ expectShareTermComponent componentHashId = do
)
`whenNothingM` do
lift . unrecoverableError $ InternalServerError "expected-term-component" (ExpectedTermComponentNotFound (That componentHashId))
second (Hash32.fromHash . unComponentHash) . Share.TermComponent . toList <$> for componentElements \(termId, LocalTermBytes bytes) -> do
textLookup <- lift $ termLocalTextReferences termId
defnLookup <- lift $ termLocalComponentReferences termId
results <- pipelined $ for componentElements \(termId, LocalTermBytes bytes) -> do
textLookup <- termLocalTextReferences termId
defnLookup <- termLocalComponentReferences termId
pure (Share.LocalIds {texts = textLookup, hashes = defnLookup}, bytes)
pure (second (Hash32.fromHash . unComponentHash) . Share.TermComponent . toList $ results)
where
checkElements :: [(TermId, Maybe LocalTermBytes)] -> Maybe (NonEmpty (TermId, LocalTermBytes))
checkElements rows =
Expand All @@ -251,10 +252,11 @@ expectShareTypeComponent componentHashId = do
)
`whenNothingM` do
lift . unrecoverableError $ InternalServerError "expected-type-component" (ExpectedTypeComponentNotFound (That componentHashId))
second (Hash32.fromHash . unComponentHash) . Share.DeclComponent . toList <$> for componentElements \(typeId, LocalTypeBytes bytes) -> do
textLookup <- lift $ typeLocalTextReferences typeId
defnLookup <- lift $ typeLocalComponentReferences typeId
results <- pipelined $ for componentElements \(typeId, LocalTypeBytes bytes) -> do
textLookup <- typeLocalTextReferences typeId
defnLookup <- typeLocalComponentReferences typeId
pure (Share.LocalIds {texts = Vector.toList textLookup, hashes = Vector.toList defnLookup}, bytes)
pure (second (Hash32.fromHash . unComponentHash) . Share.DeclComponent . toList $ results)
where
checkElements :: [(TypeId, Maybe LocalTypeBytes)] -> Maybe (NonEmpty (TypeId, LocalTypeBytes))
checkElements rows =
Expand Down Expand Up @@ -407,7 +409,7 @@ loadDecl codebaseUser (Reference.Id compHash (pgComponentIndex -> compIndex)) =
localIds = LocalIds.LocalIds {textLookup, defnLookup}
pure $ s2cDecl localIds decl

typeLocalTextReferences :: TypeId -> Transaction e (Vector Text)
typeLocalTextReferences :: (QueryA m e) => TypeId -> m (Vector Text)
typeLocalTextReferences typeId =
Vector.fromList
<$> queryListCol
Expand All @@ -419,7 +421,7 @@ typeLocalTextReferences typeId =
ORDER BY local_index ASC
|]

typeLocalComponentReferences :: TypeId -> Transaction e (Vector ComponentHash)
typeLocalComponentReferences :: (QueryA m e) => TypeId -> m (Vector ComponentHash)
typeLocalComponentReferences typeId =
Vector.fromList
<$> queryListCol
Expand Down
45 changes: 24 additions & 21 deletions src/Share/Postgres/Hashes/Queries.hs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ expectPatchHashesOf trav = do
then error "expectPatchHashesOf: Missing expected patch hash"
else pure results

expectPatchIdsOf :: (HasCallStack) => Traversal s t PatchHash PatchId -> s -> CodebaseM e t
expectPatchIdsOf :: Traversal s t PatchHash PatchId -> s -> CodebaseM e t
expectPatchIdsOf trav = do
unsafePartsOf trav %%~ \hashes -> do
codebaseOwner <- asks Codebase.codebaseOwner
Expand Down Expand Up @@ -183,7 +183,7 @@ loadBranchHashId branchHash = do
)
|]

expectBranchHashId :: (HasCallStack) => BranchHash -> CodebaseM e BranchHashId
expectBranchHashId :: BranchHash -> CodebaseM e BranchHashId
expectBranchHashId branchHash = do
loadBranchHashId branchHash >>= \case
Just hashId -> pure hashId
Expand Down Expand Up @@ -235,12 +235,12 @@ addKnownCausalHashMismatch providedHash actualHash = do
|]

-- | Generic helper which fetches both branch hashes and causal hashes
expectCausalHashesOfG :: (HasCallStack, QueryM m e) => ((BranchHash, CausalHash) -> h) -> Traversal s t CausalId h -> s -> m t
expectCausalHashesOfG :: (HasCallStack, QueryA m e) => ((BranchHash, CausalHash) -> h) -> Traversal s t CausalId h -> s -> m t
expectCausalHashesOfG project trav = do
unsafePartsOf trav %%~ \hashIds -> do
let numberedHashIds = zip [0 :: Int32 ..] hashIds
results :: [(BranchHash, CausalHash)] <-
queryListRows
unrecoverableErrorA $
queryListRows @(BranchHash, CausalHash)
[sql|
WITH causal_ids(ord, id) AS (
SELECT * FROM ^{toTable numberedHashIds}
Expand All @@ -251,17 +251,18 @@ expectCausalHashesOfG project trav = do
JOIN branch_hashes bh ON causal.namespace_hash_id = bh.id
ORDER BY causal_ids.ord ASC
|]
if length results /= length hashIds
then error "expectCausalHashesOf: Missing expected causal hash"
else pure (project <$> results)
<&> \results ->
if length results /= length hashIds
then Left . MissingExpectedEntity $ "expectCausalHashesOfG: Expected to get the same number of results as causal ids."
else pure (project <$> results)

expectCausalAndBranchHashesOf :: (HasCallStack, QueryM m e) => Traversal s t CausalId (BranchHash, CausalHash) -> s -> m t
expectCausalAndBranchHashesOf = expectCausalHashesOfG id

expectCausalHashesByIdsOf :: (HasCallStack, QueryM m e) => Traversal s t CausalId CausalHash -> s -> m t
expectCausalHashesByIdsOf :: (HasCallStack, QueryA m e) => Traversal s t CausalId CausalHash -> s -> m t
expectCausalHashesByIdsOf = expectCausalHashesOfG snd

expectCausalIdsOf :: (HasCallStack) => Traversal s t CausalHash (BranchHashId, CausalId) -> s -> CodebaseM e t
expectCausalIdsOf :: Traversal s t CausalHash (BranchHashId, CausalId) -> s -> CodebaseM e t
expectCausalIdsOf trav = do
unsafePartsOf trav %%~ \hashes -> do
codebaseOwnerId <- asks Codebase.codebaseOwner
Expand All @@ -287,12 +288,12 @@ expectCausalIdsOf trav = do
then unrecoverableError $ EntityMissing "missing-expected-causal" $ "Missing one of these causals: " <> Text.intercalate ", " (into @Text <$> hashes)
else pure results

expectNamespaceIdsByCausalIdsOf :: (QueryM m e) => Traversal s t CausalId BranchHashId -> s -> m t
expectNamespaceIdsByCausalIdsOf :: (QueryA m e) => Traversal s t CausalId BranchHashId -> s -> m t
expectNamespaceIdsByCausalIdsOf trav s = do
s
& unsafePartsOf trav %%~ \causalIds -> do
let causalIdsTable = ordered causalIds
results <-
unrecoverableErrorA $
queryListCol @(BranchHashId)
[sql| WITH causal_ids(ord, causal_id) AS (
SELECT ord, causal_id FROM ^{toTable causalIdsTable} as t(ord, causal_id)
Expand All @@ -302,16 +303,17 @@ expectNamespaceIdsByCausalIdsOf trav s = do
JOIN causals c ON cid.causal_id = c.id
ORDER BY cid.ord
|]
if length results /= length causalIds
then unrecoverableError . MissingExpectedEntity $ "expectNamespaceIdsByCausalIdsOf: Expected to get the same number of results as causal ids. " <> tShow causalIds
else pure results
<&> \results ->
if length results /= length causalIds
then Left . MissingExpectedEntity $ "expectNamespaceIdsByCausalIdsOf: Expected to get the same number of results as causal ids. " <> tShow causalIds
else Right results

expectNamespaceHashesByNamespaceHashIdsOf :: (HasCallStack, QueryM m e) => Traversal s t BranchHashId BranchHash -> s -> m t
expectNamespaceHashesByNamespaceHashIdsOf :: (QueryA m e) => Traversal s t BranchHashId BranchHash -> s -> m t
expectNamespaceHashesByNamespaceHashIdsOf trav s = do
s
& unsafePartsOf trav %%~ \namespaceHashIds -> do
let namespaceHashIdsTable = ordered namespaceHashIds
results <-
unrecoverableErrorA $
queryListCol @(BranchHash)
[sql| WITH namespace_hash_ids(ord, namespace_hash_id) AS (
SELECT ord, namespace_hash_id FROM ^{toTable namespaceHashIdsTable} as t(ord, namespace_hash_id)
Expand All @@ -321,9 +323,10 @@ expectNamespaceHashesByNamespaceHashIdsOf trav s = do
JOIN branch_hashes bh ON nhi.namespace_hash_id = bh.id
ORDER BY nhi.ord
|]
if length results /= length namespaceHashIds
then unrecoverableError . MissingExpectedEntity $ "expectNamespaceHashesByNamespaceHashIdsOf: Expected to get the same number of results as namespace hash ids. " <> tShow namespaceHashIds
else pure results
<&> \results ->
if length results /= length namespaceHashIds
then Left . MissingExpectedEntity $ "expectNamespaceHashesByNamespaceHashIdsOf: Expected to get the same number of results as namespace hash ids. " <> tShow namespaceHashIds
else Right results

loadCausalIdByHash :: CausalHash -> Codebase.CodebaseM e (Maybe CausalId)
loadCausalIdByHash causalHash = do
Expand All @@ -334,7 +337,7 @@ loadCausalIdByHash causalHash = do
AND EXISTS (SELECT FROM causal_ownership o WHERE o.causal_id = causals.id AND o.user_id = #{codebaseOwner})
|]

expectCausalIdByHash :: (HasCallStack) => CausalHash -> Codebase.CodebaseM e CausalId
expectCausalIdByHash :: CausalHash -> Codebase.CodebaseM e CausalId
expectCausalIdByHash causalHash = do
loadCausalIdByHash causalHash
`whenNothingM` unrecoverableError (MissingExpectedEntity $ "Expected causal id for hash: " <> tShow causalHash)
Loading

0 comments on commit fc80be6

Please sign in to comment.