Skip to content

Commit

Permalink
DataSync: support column names that are not in snake_case
Browse files Browse the repository at this point in the history
  • Loading branch information
mpscholten committed Oct 26, 2024
1 parent 9bbe416 commit 96d41a5
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 24 deletions.
3 changes: 2 additions & 1 deletion IHP/DataSync/Controller.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import IHP.DataSync.RowLevelSecurity
import qualified Database.PostgreSQL.Simple.ToField as PG
import qualified IHP.DataSync.ChangeNotifications as ChangeNotifications
import IHP.DataSync.ControllerImpl (runDataSyncController, cleanupAllSubscriptions)
import IHP.DataSync.DynamicQueryCompiler (camelCaseRenamer)

instance (
PG.ToField (PrimaryKey (GetTableName CurrentUserRecord))
Expand All @@ -21,5 +22,5 @@ instance (
run = do
ensureRLSEnabled <- makeCachedEnsureRLSEnabled
installTableChangeTriggers <- ChangeNotifications.makeCachedInstallTableChangeTriggers
runDataSyncController ensureRLSEnabled installTableChangeTriggers (receiveData @ByteString) sendJSON (\_ _ -> pure ())
runDataSyncController ensureRLSEnabled installTableChangeTriggers (receiveData @ByteString) sendJSON (\_ _ -> pure ()) (\_ -> camelCaseRenamer)
onClose = cleanupAllSubscriptions
55 changes: 34 additions & 21 deletions IHP/DataSync/ControllerImpl.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ runDataSyncController ::
, Typeable CurrentUserRecord
, HasNewSessionUrl CurrentUserRecord
, Show (PrimaryKey (GetTableName CurrentUserRecord))
) => EnsureRLSEnabledFn -> InstallTableChangeTriggerFn -> IO ByteString -> SendJSONFn -> HandleCustomMessageFn -> IO ()
runDataSyncController ensureRLSEnabled installTableChangeTriggers receiveData sendJSON handleCustomMessage = do
) => EnsureRLSEnabledFn -> InstallTableChangeTriggerFn -> IO ByteString -> SendJSONFn -> HandleCustomMessageFn -> (Text -> Renamer) -> IO ()
runDataSyncController ensureRLSEnabled installTableChangeTriggers receiveData sendJSON handleCustomMessage renamer = do
setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty, asyncs = [] }

let handleMessage = buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleCustomMessage
let handleMessage :: DataSyncMessage -> IO () = buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleCustomMessage renamer

forever do
message <- Aeson.eitherDecodeStrict' <$> receiveData
Expand Down Expand Up @@ -91,17 +91,18 @@ buildMessageHandler ::
, HasNewSessionUrl CurrentUserRecord
, Show (PrimaryKey (GetTableName CurrentUserRecord))
)
=> EnsureRLSEnabledFn -> InstallTableChangeTriggerFn -> SendJSONFn -> HandleCustomMessageFn -> (DataSyncMessage -> IO ())
buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleCustomMessage = handleMessage
=> EnsureRLSEnabledFn -> InstallTableChangeTriggerFn -> SendJSONFn -> HandleCustomMessageFn -> (Text -> Renamer) -> (DataSyncMessage -> IO ())
buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleCustomMessage renamer = handleMessage
where
pgListener = ?applicationContext.pgListener
handleMessage :: DataSyncMessage -> IO ()
handleMessage DataSyncQuery { query, requestId, transactionId } = do
ensureRLSEnabled (query.table)

let (theQuery, theParams) = compileQuery query
let (theQuery, theParams) = compileQueryWithRenamer (renamer query.table) query

result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId theQuery theParams
rawResult :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId theQuery theParams
let result = map (map (renameField (renamer query.table))) rawResult

sendJSON DataSyncResult { result, requestId }

Expand All @@ -118,16 +119,17 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
close <- MVar.newEmptyMVar
atomicModifyIORef'' ?state (\state -> state |> modify #subscriptions (HashMap.insert subscriptionId close))

let (theQuery, theParams) = compileQuery query
let (theQuery, theParams) = compileQueryWithRenamer (renamer query.table) query

result :: [[Field]] <- sqlQueryWithRLS theQuery theParams
rawResult :: [[Field]] <- sqlQueryWithRLS theQuery theParams
let result = map (map (renameField (renamer query.table))) rawResult

let tableName = query.table

-- We need to keep track of all the ids of entities we're watching to make
-- sure that we only send update notifications to clients that can actually
-- access the record (e.g. if a RLS policy denies access)
let watchedRecordIds = recordIds result
let watchedRecordIds = recordIds rawResult

-- Store it in IORef as an INSERT requires us to add an id
watchedRecordIdsRef <- newIORef (Set.fromList watchedRecordIds)
Expand All @@ -149,11 +151,12 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
newRecord :: [[Field]] <- sqlQueryWithRLS ("SELECT * FROM (" <> theQuery <> ") AS records WHERE records.id = ? LIMIT 1") (theParams <> [PG.toField id])
case headMay newRecord of
Just record -> do
Just rawRecord -> do
-- Add the new record to 'watchedRecordIdsRef'
-- Otherwise the updates and deletes will not be dispatched to the client
modifyIORef' watchedRecordIdsRef (Set.insert id)
let record = map (renameField (renamer tableName)) rawRecord
sendJSON DidInsert { subscriptionId, record }
Nothing -> pure ()
ChangeNotifications.DidUpdate { id, changeSet } -> do
Expand All @@ -167,7 +170,7 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
changes <- ChangeNotifications.retrieveChanges changeSet
if isRecordInResultSet
then sendJSON DidUpdate { subscriptionId, id, changeSet = changesToValue changes }
then sendJSON DidUpdate { subscriptionId, id, changeSet = changesToValue (renamer tableName) changes }
else sendJSON DidDelete { subscriptionId, id }
ChangeNotifications.DidDelete { id } -> do
-- Only send the notifcation if the deleted record was part of the initial
Expand Down Expand Up @@ -202,7 +205,7 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
let query = "INSERT INTO ? ? VALUES ? RETURNING *"
let columns = record
|> HashMap.keys
|> map fieldNameToColumnName
|> map (renamer table).fieldToColumn
let values = record
|> HashMap.elems
Expand All @@ -213,7 +216,11 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
case result of
[record] -> sendJSON DidCreateRecord { requestId, record }
[rawRecord] ->
let
record = map (renameField (renamer table)) rawRecord
in
sendJSON DidCreateRecord { requestId, record }
otherwise -> error "Unexpected result in CreateRecordMessage handler"
pure ()
Expand All @@ -228,7 +235,7 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
Just value -> value
Nothing -> error "Atleast one record is required"
|> HashMap.keys
|> map fieldNameToColumnName
|> map (renamer table).fieldToColumn
let values = records
|> map (\object ->
Expand All @@ -240,7 +247,8 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.Values [] values)
records :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
rawRecords :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
let records = map (map (renameField (renamer table))) rawRecords
sendJSON DidCreateRecords { requestId, records }
Expand Down Expand Up @@ -272,7 +280,11 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
case result of
[record] -> sendJSON DidUpdateRecord { requestId, record }
[rawRecord] ->
let
record = map (renameField (renamer table)) rawRecord
in
sendJSON DidUpdateRecord { requestId, record }
otherwise -> error "Could not apply the update to the given record. Are you sure the record ID you passed is correct? If the record ID is correct, likely the row level security policy is not making the record visible to the UPDATE operation."
pure ()
Expand Down Expand Up @@ -300,7 +312,8 @@ buildMessageHandler ensureRLSEnabled installTableChangeTriggers sendJSON handleC
<> (join (map (\(key, value) -> [PG.toField key, value]) keyValues))
<> [PG.toField (PG.In ids)]
records <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
rawRecords <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
let records = map (map (renameField (renamer table))) rawRecords
sendJSON DidUpdateRecords { requestId, records }
Expand Down Expand Up @@ -380,10 +393,10 @@ cleanupAllSubscriptions = do
DataSyncReady { asyncs } -> forEach asyncs uninterruptibleCancel
_ -> pure ()
changesToValue :: [ChangeNotifications.Change] -> Value
changesToValue changes = object (map changeToPair changes)
changesToValue :: Renamer -> [ChangeNotifications.Change] -> Value
changesToValue renamer changes = object (map changeToPair changes)
where
changeToPair ChangeNotifications.Change { col, new } = (Aeson.fromText $ columnNameToFieldName col) .= new
changeToPair ChangeNotifications.Change { col, new } = (Aeson.fromText $ renamer.columnToField col) .= new
runInModelContextWithTransaction :: (?state :: IORef DataSyncController, ?modelContext :: ModelContext) => ((?modelContext :: ModelContext) => IO result) -> Maybe UUID -> IO result
runInModelContextWithTransaction function (Just transactionId) = do
Expand Down
2 changes: 1 addition & 1 deletion IHP/DataSync/DynamicQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ instance PG.FromField Field where
pure Field { .. }
where
fieldName = (PG.name field)
|> fmap (columnNameToFieldName . cs)
|> fmap cs
|> fromMaybe ""

instance PG.FromField DynamicValue where
Expand Down
34 changes: 33 additions & 1 deletion IHP/DataSync/DynamicQueryCompiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,40 @@ import qualified Database.PostgreSQL.Simple.ToField as PG
import qualified Database.PostgreSQL.Simple.Types as PG
import qualified Data.List as List

data Renamer = Renamer
{ fieldToColumn :: Text -> Text
, columnToField :: Text -> Text
}

compileQuery :: DynamicSQLQuery -> (PG.Query, [PG.Action])
compileQuery query = compileQueryMapped (mapColumnNames fieldNameToColumnName query)
compileQuery = compileQueryWithRenamer camelCaseRenamer

compileQueryWithRenamer :: Renamer -> DynamicSQLQuery -> (PG.Query, [PG.Action])
compileQueryWithRenamer renamer query = compileQueryMapped (mapColumnNames renamer.fieldToColumn query)

-- | Default renamer used by DataSync.
--
-- Transforms JS inputs in @camelCase@ to snake_case for the database
-- and DB outputs in @snake_case@ back to @camelCase@
camelCaseRenamer :: Renamer
camelCaseRenamer =
Renamer
{ fieldToColumn = fieldNameToColumnName
, columnToField = columnNameToFieldName
}

-- | Renamer that does not modify the column names
unmodifiedRenamer :: Renamer
unmodifiedRenamer =
Renamer
{ fieldToColumn = id
, columnToField = id
}

-- | When a Field is retrieved from the database, it's all in @snake_case@. This turns it into @camelCase@
renameField :: Renamer -> Field -> Field
renameField renamer field =
field { fieldName = renamer.columnToField field.fieldName }

compileQueryMapped :: DynamicSQLQuery -> (PG.Query, [PG.Action])
compileQueryMapped DynamicSQLQuery { .. } = (sql, args)
Expand Down

0 comments on commit 96d41a5

Please sign in to comment.