Skip to content

Commit 7698516

Browse files
authored
Merge pull request #34 from unisoncomputing/syncv2/causal-negotiation
Naïve dependency negotiation API
2 parents 9399c85 + 1343a7c commit 7698516

File tree

6 files changed

+333
-218
lines changed

6 files changed

+333
-218
lines changed

share-api.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ library
158158
Share.Web.UCM.SyncV2.API
159159
Share.Web.UCM.SyncV2.Impl
160160
Share.Web.UCM.SyncV2.Queries
161+
Share.Web.UCM.SyncV2.Types
161162
Unison.PrettyPrintEnvDecl.Postgres
162163
Unison.Server.NameSearch.Postgres
163164
Unison.Server.Share.Definitions
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
-- Takes a causal_id and returns a table of ALL hashes which are dependencies of that causal.
2+
CREATE OR REPLACE FUNCTION dependencies_of_causals(the_causal_ids INTEGER[]) RETURNS TABLE (hash TEXT) AS $$
3+
WITH RECURSIVE all_causals(causal_id, causal_hash, causal_namespace_hash_id) AS (
4+
-- Base causal
5+
SELECT DISTINCT causal.id, causal.hash, causal.namespace_hash_id
6+
FROM UNNEST(the_causal_ids) AS causal_id
7+
JOIN causals causal ON causal.id = causal_id
8+
UNION
9+
-- This nested CTE is required because RECURSIVE CTEs can't refer
10+
-- to the recursive table more than once.
11+
-- I don't fully understand why or how this works, but it does
12+
( WITH rec AS (
13+
SELECT tc.causal_id, tc.causal_namespace_hash_id
14+
FROM all_causals tc
15+
)
16+
SELECT ancestor_causal.id, ancestor_causal.hash, ancestor_causal.namespace_hash_id
17+
FROM causal_ancestors ca
18+
JOIN rec tc ON ca.causal_id = tc.causal_id
19+
JOIN causals ancestor_causal ON ca.ancestor_id = ancestor_causal.id
20+
UNION
21+
SELECT child_causal.id, child_causal.hash, child_causal.namespace_hash_id
22+
FROM rec tc
23+
JOIN namespace_children nc ON tc.causal_namespace_hash_id = nc.parent_namespace_hash_id
24+
JOIN causals child_causal ON nc.child_causal_id = child_causal.id
25+
)
26+
), all_namespaces(namespace_hash_id, namespace_hash) AS (
27+
SELECT DISTINCT tc.causal_namespace_hash_id AS namespace_hash_id, bh.base32 as namespace_hash
28+
FROM all_causals tc
29+
JOIN branch_hashes bh ON tc.causal_namespace_hash_id = bh.id
30+
), all_patches(patch_id, patch_hash) AS (
31+
SELECT DISTINCT patch.id, patch.hash
32+
FROM all_namespaces an
33+
JOIN namespace_patches np ON an.namespace_hash_id = np.namespace_hash_id
34+
JOIN patches patch ON np.patch_id = patch.id
35+
),
36+
-- term components to start transitively joining dependencies to
37+
base_term_components(component_hash_id) AS (
38+
SELECT DISTINCT term.component_hash_id
39+
FROM all_namespaces an
40+
JOIN namespace_terms nt ON an.namespace_hash_id = nt.namespace_hash_id
41+
JOIN terms term ON nt.term_id = term.id
42+
UNION
43+
SELECT DISTINCT term.component_hash_id
44+
FROM all_patches ap
45+
JOIN patch_term_mappings ptm ON ap.patch_id = ptm.patch_id
46+
JOIN terms term ON ptm.to_term_id = term.id
47+
UNION
48+
-- term metadata
49+
SELECT DISTINCT term.component_hash_id
50+
FROM all_namespaces an
51+
JOIN namespace_terms nt ON an.namespace_hash_id = nt.namespace_hash_id
52+
JOIN namespace_term_metadata meta ON nt.id = meta.named_term
53+
JOIN terms term ON meta.metadata_term_id = term.id
54+
UNION
55+
-- type metadata
56+
SELECT DISTINCT term.component_hash_id
57+
FROM all_namespaces an
58+
JOIN namespace_types nt ON an.namespace_hash_id = nt.namespace_hash_id
59+
JOIN namespace_type_metadata meta ON nt.id = meta.named_type
60+
JOIN terms term ON meta.metadata_term_id = term.id
61+
),
62+
-- type components to start transitively joining dependencies to
63+
base_type_components(component_hash_id) AS (
64+
SELECT DISTINCT typ.component_hash_id
65+
FROM all_namespaces an
66+
JOIN namespace_types nt ON an.namespace_hash_id = nt.namespace_hash_id
67+
JOIN types typ ON nt.type_id = typ.id
68+
UNION
69+
SELECT DISTINCT typ.component_hash_id
70+
FROM all_namespaces an
71+
JOIN namespace_terms nt ON an.namespace_hash_id = nt.namespace_hash_id
72+
JOIN constructors con ON nt.constructor_id = con.id
73+
JOIN types typ ON con.type_id = typ.id
74+
UNION
75+
SELECT DISTINCT typ.component_hash_id
76+
FROM all_patches ap
77+
JOIN patch_type_mappings ptm ON ap.patch_id = ptm.patch_id
78+
JOIN types typ ON ptm.to_type_id = typ.id
79+
UNION
80+
SELECT DISTINCT typ.component_hash_id
81+
FROM all_patches ap
82+
JOIN patch_constructor_mappings pcm ON ap.patch_id = pcm.patch_id
83+
JOIN constructors con ON pcm.to_constructor_id = con.id
84+
JOIN types typ ON con.type_id = typ.id
85+
),
86+
-- All the dependencies we join in transitively from the known term & type components we depend on.
87+
all_components(component_hash_id) AS (
88+
SELECT DISTINCT btc.component_hash_id
89+
FROM base_term_components btc
90+
UNION
91+
SELECT DISTINCT btc.component_hash_id
92+
FROM base_type_components btc
93+
UNION
94+
( WITH rec AS (
95+
SELECT DISTINCT ac.component_hash_id
96+
FROM all_components ac
97+
)
98+
-- recursively union in term dependencies
99+
SELECT DISTINCT ref.component_hash_id
100+
FROM rec atc
101+
-- This joins in ALL the terms from the component, not just the one that caused the dependency on the
102+
-- component
103+
JOIN terms term ON atc.component_hash_id = term.component_hash_id
104+
JOIN term_local_component_references ref ON term.id = ref.term_id
105+
UNION
106+
-- recursively union in type dependencies
107+
SELECT DISTINCT ref.component_hash_id
108+
FROM rec atc
109+
-- This joins in ALL the types from the component, not just the one that caused the dependency on the
110+
-- component
111+
JOIN types typ ON atc.component_hash_id = typ.component_hash_id
112+
JOIN type_local_component_references ref ON typ.id = ref.type_id
113+
)
114+
)
115+
(SELECT ch.base32 AS hash
116+
FROM all_components ac
117+
JOIN component_hashes ch ON ac.component_hash_id = ch.id
118+
)
119+
UNION ALL
120+
(SELECT ap.patch_hash AS hash
121+
FROM all_patches ap
122+
)
123+
UNION ALL
124+
(SELECT an.namespace_hash AS hash
125+
FROM all_namespaces an
126+
)
127+
UNION ALL
128+
(SELECT ac.causal_hash AS hash
129+
FROM all_causals ac
130+
)
131+
$$ LANGUAGE SQL;

src/Share/Web/UCM/SyncV2/Impl.hs

Lines changed: 103 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ import Codec.Serialise qualified as CBOR
77
import Conduit qualified as C
88
import Control.Concurrent.STM qualified as STM
99
import Control.Concurrent.STM.TBMQueue qualified as STM
10-
import Control.Monad.Except (ExceptT (ExceptT))
10+
import Control.Monad.Except (ExceptT (ExceptT), withExceptT)
1111
import Control.Monad.Trans.Except (runExceptT)
1212
import Data.Binary.Builder qualified as Builder
13-
import Data.Vector (Vector)
13+
import Data.Set qualified as Set
14+
import Data.Text.Encoding qualified as Text
1415
import Data.Vector qualified as Vector
16+
import Ki.Unlifted qualified as Ki
1517
import Servant
1618
import Servant.Conduit (ConduitToSourceIO (..))
1719
import Servant.Types.SourceT (SourceT (..))
@@ -33,14 +35,15 @@ import Share.Web.Authorization qualified as AuthZ
3335
import Share.Web.Errors
3436
import Share.Web.UCM.Sync.HashJWT qualified as HashJWT
3537
import Share.Web.UCM.SyncV2.Queries qualified as SSQ
38+
import Share.Web.UCM.SyncV2.Types (IsCausalSpine (..), IsLibRoot (..))
3639
import U.Codebase.Sqlite.Orphans ()
40+
import Unison.Debug qualified as Debug
3741
import Unison.Hash32 (Hash32)
3842
import Unison.Share.API.Hash (HashJWTClaims (..))
3943
import Unison.SyncV2.API qualified as SyncV2
40-
import Unison.SyncV2.Types (DownloadEntitiesChunk (..), EntityChunk (..), ErrorChunk (..), StreamInitInfo (..))
44+
import Unison.SyncV2.Types (CausalDependenciesChunk (..), DependencyType (..), DownloadEntitiesChunk (..), EntityChunk (..), ErrorChunk (..), StreamInitInfo (..))
4145
import Unison.SyncV2.Types qualified as SyncV2
4246
import UnliftIO qualified
43-
import UnliftIO.Async qualified as Async
4447

4548
batchSize :: Int32
4649
batchSize = 1000
@@ -51,7 +54,8 @@ streamSettings rootCausalHash rootBranchRef = StreamInitInfo {version = SyncV2.V
5154
server :: Maybe UserId -> SyncV2.Routes WebAppServer
5255
server mayUserId =
5356
SyncV2.Routes
54-
{ downloadEntitiesStream = downloadEntitiesStreamImpl mayUserId
57+
{ downloadEntitiesStream = downloadEntitiesStreamImpl mayUserId,
58+
causalDependenciesStream = causalDependenciesStreamImpl mayUserId
5559
}
5660

5761
parseBranchRef :: SyncV2.BranchRef -> Either Text (Either ProjectReleaseShortHand ProjectBranchShortHand)
@@ -66,30 +70,16 @@ parseBranchRef (SyncV2.BranchRef branchRef) =
6670
parseRelease = fmap Left . eitherToMaybe $ IDs.fromText @ProjectReleaseShortHand branchRef
6771

6872
downloadEntitiesStreamImpl :: Maybe UserId -> SyncV2.DownloadEntitiesRequest -> WebApp (SourceIO (SyncV2.CBORStream SyncV2.DownloadEntitiesChunk))
69-
downloadEntitiesStreamImpl mayCallerUserId (SyncV2.DownloadEntitiesRequest {causalHash = causalHashJWT, branchRef, knownHashes = _todo}) = do
73+
downloadEntitiesStreamImpl mayCallerUserId (SyncV2.DownloadEntitiesRequest {causalHash = causalHashJWT, branchRef, knownHashes}) = do
7074
either emitErr id <$> runExceptT do
7175
addRequestTag "branch-ref" (SyncV2.unBranchRef branchRef)
7276
HashJWTClaims {hash = causalHash} <- lift (HashJWT.verifyHashJWT mayCallerUserId causalHashJWT >>= either respondError pure)
7377
codebase <-
74-
case parseBranchRef branchRef of
75-
Left err -> throwError (SyncV2.DownloadEntitiesInvalidBranchRef err branchRef)
76-
Right (Left (ProjectReleaseShortHand {userHandle, projectSlug})) -> do
77-
let projectShortHand = ProjectShortHand {userHandle, projectSlug}
78-
(Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do
79-
project <- PGQ.projectByShortHand projectShortHand `whenNothingM` throwError (SyncV2.DownloadEntitiesProjectNotFound $ IDs.toText @ProjectShortHand projectShortHand)
80-
pure (project, Nothing)
81-
authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (SyncV2.DownloadEntitiesNoReadPermission branchRef)
82-
let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId
83-
pure $ Codebase.codebaseEnv authZToken codebaseLoc
84-
Right (Right (ProjectBranchShortHand {userHandle, projectSlug, contributorHandle})) -> do
85-
let projectShortHand = ProjectShortHand {userHandle, projectSlug}
86-
(Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do
87-
project <- (PGQ.projectByShortHand projectShortHand) `whenNothingM` throwError (SyncV2.DownloadEntitiesProjectNotFound $ IDs.toText @ProjectShortHand projectShortHand)
88-
mayContributorUserId <- for contributorHandle \ch -> fmap user_id $ (PGQ.userByHandle ch) `whenNothingM` throwError (SyncV2.DownloadEntitiesUserNotFound $ IDs.toText @UserHandle ch)
89-
pure (project, mayContributorUserId)
90-
authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (SyncV2.DownloadEntitiesNoReadPermission branchRef)
91-
let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId
92-
pure $ Codebase.codebaseEnv authZToken codebaseLoc
78+
flip withExceptT (codebaseForBranchRef branchRef) \case
79+
CodebaseLoadingErrorProjectNotFound projectShortHand -> SyncV2.DownloadEntitiesProjectNotFound (IDs.toText projectShortHand)
80+
CodebaseLoadingErrorUserNotFound userHandle -> SyncV2.DownloadEntitiesUserNotFound (IDs.toText userHandle)
81+
CodebaseLoadingErrorNoReadPermission branchRef -> SyncV2.DownloadEntitiesNoReadPermission branchRef
82+
CodebaseLoadingErrorInvalidBranchRef err branchRef -> SyncV2.DownloadEntitiesInvalidBranchRef err branchRef
9383
q <- UnliftIO.atomically $ do
9484
q <- STM.newTBMQueue 10
9585
STM.writeTBMQueue q (Vector.singleton $ InitialC $ streamSettings causalHash (Just branchRef))
@@ -98,39 +88,107 @@ downloadEntitiesStreamImpl mayCallerUserId (SyncV2.DownloadEntitiesRequest {caus
9888
Logging.logInfoText "Starting download entities stream"
9989
Codebase.runCodebaseTransaction codebase $ do
10090
(_bhId, causalId) <- CausalQ.expectCausalIdsOf id (hash32ToCausalHash causalHash)
101-
cursor <- SSQ.allSerializedDependenciesOfCausalCursor causalId
91+
let knownCausalHashes = Set.map hash32ToCausalHash knownHashes
92+
cursor <- SSQ.allSerializedDependenciesOfCausalCursor causalId knownCausalHashes
10293
Cursor.foldBatched cursor batchSize \batch -> do
10394
let entityChunkBatch = batch <&> \(entityCBOR, hash) -> EntityC (EntityChunk {hash, entityCBOR})
10495
PG.transactionUnsafeIO $ STM.atomically $ STM.writeTBMQueue q entityChunkBatch
10596
PG.transactionUnsafeIO $ STM.atomically $ STM.closeTBMQueue q
10697
pure $ sourceIOWithAsync streamResults $ conduitToSourceIO do
107-
stream q
98+
queueToStream q
10899
where
109-
stream :: STM.TBMQueue (Vector DownloadEntitiesChunk) -> C.ConduitT () (SyncV2.CBORStream DownloadEntitiesChunk) IO ()
110-
stream q = do
111-
let loop :: C.ConduitT () (SyncV2.CBORStream DownloadEntitiesChunk) IO ()
112-
loop = do
113-
liftIO (STM.atomically (STM.readTBMQueue q)) >>= \case
114-
-- The queue is closed.
115-
Nothing -> do
116-
pure ()
117-
Just batches -> do
118-
batches
119-
& foldMap (CBOR.serialiseIncremental)
120-
& (SyncV2.CBORStream . Builder.toLazyByteString)
121-
& C.yield
122-
loop
123-
124-
loop
125-
126100
emitErr :: SyncV2.DownloadEntitiesError -> SourceIO (SyncV2.CBORStream SyncV2.DownloadEntitiesChunk)
127101
emitErr err = SourceT.source [SyncV2.CBORStream . CBOR.serialise $ ErrorC (ErrorChunk err)]
128102

103+
causalDependenciesStreamImpl :: Maybe UserId -> SyncV2.CausalDependenciesRequest -> WebApp (SourceIO (SyncV2.CBORStream SyncV2.CausalDependenciesChunk))
104+
causalDependenciesStreamImpl mayCallerUserId (SyncV2.CausalDependenciesRequest {rootCausal = causalHashJWT, branchRef}) = do
105+
respondExceptT do
106+
addRequestTag "branch-ref" (SyncV2.unBranchRef branchRef)
107+
HashJWTClaims {hash = causalHash} <- lift (HashJWT.verifyHashJWT mayCallerUserId causalHashJWT >>= either respondError pure)
108+
addRequestTag "root-causal" (tShow causalHash)
109+
codebase <- codebaseForBranchRef branchRef
110+
q <- UnliftIO.atomically $ STM.newTBMQueue 10
111+
streamResults <- lift $ UnliftIO.toIO do
112+
Logging.logInfoText "Starting causal dependencies stream"
113+
Codebase.runCodebaseTransaction codebase $ do
114+
(_bhId, causalId) <- CausalQ.expectCausalIdsOf id (hash32ToCausalHash causalHash)
115+
Debug.debugLogM Debug.Temp "Getting cursor"
116+
cursor <- SSQ.spineAndLibDependenciesOfCausalCursor causalId
117+
Debug.debugLogM Debug.Temp "Folding cursor"
118+
Cursor.foldBatched cursor batchSize \batch -> do
119+
Debug.debugLogM Debug.Temp "Got batch"
120+
let depBatch =
121+
batch <&> \(causalHash, isCausalSpine, isLibRoot) ->
122+
let dependencyType = case (isCausalSpine, isLibRoot) of
123+
(IsCausalSpine, _) -> CausalSpineDependency
124+
(_, IsLibRoot) -> LibDependency
125+
_ -> error $ "Causal dependency which is neither spine nor lib root: " <> show causalHash
126+
in CausalHashDepC {causalHash, dependencyType}
127+
PG.transactionUnsafeIO $ STM.atomically $ STM.writeTBMQueue q depBatch
128+
PG.transactionUnsafeIO $ STM.atomically $ STM.closeTBMQueue q
129+
pure $ sourceIOWithAsync streamResults $ conduitToSourceIO do
130+
queueToStream q
131+
132+
queueToStream :: forall a f. (CBOR.Serialise a, Foldable f) => STM.TBMQueue (f a) -> C.ConduitT () (SyncV2.CBORStream a) IO ()
133+
queueToStream q = do
134+
let loop :: C.ConduitT () (SyncV2.CBORStream a) IO ()
135+
loop = do
136+
liftIO (STM.atomically (STM.readTBMQueue q)) >>= \case
137+
-- The queue is closed.
138+
Nothing -> do
139+
pure ()
140+
Just batches -> do
141+
batches
142+
& foldMap (CBOR.serialiseIncremental)
143+
& (SyncV2.CBORStream . Builder.toLazyByteString)
144+
& C.yield
145+
loop
146+
loop
147+
148+
data CodebaseLoadingError
149+
= CodebaseLoadingErrorProjectNotFound ProjectShortHand
150+
| CodebaseLoadingErrorUserNotFound UserHandle
151+
| CodebaseLoadingErrorNoReadPermission SyncV2.BranchRef
152+
| CodebaseLoadingErrorInvalidBranchRef Text SyncV2.BranchRef
153+
deriving stock (Show)
154+
deriving (Logging.Loggable) via Logging.ShowLoggable Logging.UserFault CodebaseLoadingError
155+
156+
instance ToServerError CodebaseLoadingError where
157+
toServerError = \case
158+
CodebaseLoadingErrorProjectNotFound projectShortHand -> (ErrorID "codebase-loading:project-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "Project not found: " <> (IDs.toText projectShortHand)})
159+
CodebaseLoadingErrorUserNotFound userHandle -> (ErrorID "codebase-loading:user-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "User not found: " <> (IDs.toText userHandle)})
160+
CodebaseLoadingErrorNoReadPermission branchRef -> (ErrorID "codebase-loading:no-read-permission", Servant.err403 {errBody = from . Text.encodeUtf8 $ "No read permission for branch ref: " <> (SyncV2.unBranchRef branchRef)})
161+
CodebaseLoadingErrorInvalidBranchRef err branchRef -> (ErrorID "codebase-loading:invalid-branch-ref", Servant.err400 {errBody = from . Text.encodeUtf8 $ "Invalid branch ref: " <> err <> " " <> (SyncV2.unBranchRef branchRef)})
162+
163+
codebaseForBranchRef :: SyncV2.BranchRef -> (ExceptT CodebaseLoadingError WebApp Codebase.CodebaseEnv)
164+
codebaseForBranchRef branchRef = do
165+
case parseBranchRef branchRef of
166+
Left err -> throwError (CodebaseLoadingErrorInvalidBranchRef err branchRef)
167+
Right (Left (ProjectReleaseShortHand {userHandle, projectSlug})) -> do
168+
let projectShortHand = ProjectShortHand {userHandle, projectSlug}
169+
(Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do
170+
project <- PGQ.projectByShortHand projectShortHand `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound $ projectShortHand)
171+
pure (project, Nothing)
172+
authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef)
173+
let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId
174+
pure $ Codebase.codebaseEnv authZToken codebaseLoc
175+
Right (Right (ProjectBranchShortHand {userHandle, projectSlug, contributorHandle})) -> do
176+
let projectShortHand = ProjectShortHand {userHandle, projectSlug}
177+
(Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do
178+
project <- (PGQ.projectByShortHand projectShortHand) `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound projectShortHand)
179+
mayContributorUserId <- for contributorHandle \ch -> fmap user_id $ (PGQ.userByHandle ch) `whenNothingM` throwError (CodebaseLoadingErrorUserNotFound ch)
180+
pure (project, mayContributorUserId)
181+
authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef)
182+
let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId
183+
pure $ Codebase.codebaseEnv authZToken codebaseLoc
184+
129185
-- | Run an IO action in the background while streaming the results.
130186
--
131187
-- Servant doesn't provide any easier way to do bracketing like this, all the IO must be
132188
-- inside the SourceIO somehow.
133189
sourceIOWithAsync :: IO a -> SourceIO r -> SourceIO r
134190
sourceIOWithAsync action (SourceT k) =
135191
SourceT \k' ->
136-
Async.withAsync action \_ -> k k'
192+
Ki.scoped \scope -> do
193+
_ <- Ki.fork scope action
194+
k k'

0 commit comments

Comments
 (0)