diff --git a/Guide/auto-refresh.markdown b/Guide/auto-refresh.markdown index 4877f175f..1fd10c573 100644 --- a/Guide/auto-refresh.markdown +++ b/Guide/auto-refresh.markdown @@ -98,6 +98,15 @@ action MyAction = do -- <-- We don't enable auto refresh at the action start in render MyView { expensiveModels, cheap } ``` +### Smart Filtering + +`autoRefresh` automatically tracks both the row IDs and the WHERE conditions of queries fetched during your action. This is used to skip unnecessary re-renders: + +- **UPDATE / DELETE**: When a notification arrives for a row whose ID is not in the tracked set, the re-render is skipped. +- **INSERT**: When a new row is inserted, IHP evaluates your query's WHERE conditions against the inserted row. If the new row doesn't match your filters, the re-render is skipped. For example, if your action fetches `query @Task |> filterWhere (#projectId, myProjectId) |> fetch`, inserting a task with a different `projectId` will not trigger a re-render. + +This happens transparently — no configuration needed. For tables accessed via raw SQL or `fetchCount` (where individual row IDs aren't available), auto refresh falls back to refreshing on every change. Conditions that can't be evaluated at notification time (e.g. `LIKE`, `LOWER()`, range operators) also fall back to refreshing. + ### Custom SQL Queries with Auto Refresh Auto Refresh automatically tracks all tables your action is using by hooking itself into the Query Builder and `fetch` functions. diff --git a/ihp/IHP/AutoRefresh.hs b/ihp/IHP/AutoRefresh.hs index dd93e8116..588c062be 100644 --- a/ihp/IHP/AutoRefresh.hs +++ b/ihp/IHP/AutoRefresh.hs @@ -3,17 +3,41 @@ Module: IHP.AutoRefresh Description: Provides automatically diff-based refreshing views after page load Copyright: (c) digitally induced GmbH, 2020 -} -module IHP.AutoRefresh where +module IHP.AutoRefresh +( autoRefresh +, registerNotificationTrigger +, shouldRefreshForPayload +, matchesInsertPayload +, matchesInsertPayloadDynamic +, extractRowId +, lookupColumn +, jsonValueMatchesText +, getAvailableSessions +, getSessionById +, updateSession +, gcSessions +, channelName +, notificationTriggerStatements +, resolveAutoRefreshPayload +, autoRefreshStateVaultKey +, globalAutoRefreshServerVar +, AutoRefreshWSApp (..) +) where import IHP.Prelude import IHP.AutoRefresh.Types import IHP.ControllerSupport +import qualified Data.Aeson as Aeson +import qualified Data.Aeson.KeyMap as AesonKeyMap +import qualified Data.Aeson.Key as AesonKey import qualified Data.UUID.V4 as UUID import qualified Data.UUID as UUID import IHP.Controller.Session import qualified Network.Wai.Internal as Wai import qualified Data.Binary.Builder as ByteString import qualified Data.Set as Set +import qualified Data.Map.Strict as Map +import qualified Database.PostgreSQL.Simple.Types as PG import IHP.ModelSupport import qualified Control.Exception as Exception import qualified Control.Concurrent.MVar as MVar @@ -32,6 +56,8 @@ import qualified Data.TMap as TypeMap import IHP.RequestVault (pgListenerVaultKey) import IHP.FrameworkConfig.Types (FrameworkConfig(..)) import IHP.Environment (Environment(..)) +import Data.Dynamic (Dynamic, fromDynamic) +import IHP.QueryBuilder.Types (Condition(..), ConditionValue(..), FilterOperator(..), getParamPrinterText) {-# NOINLINE globalAutoRefreshServerVar #-} globalAutoRefreshServerVar :: MVar.MVar (Maybe (IORef AutoRefreshServer)) @@ -107,6 +133,8 @@ autoRefresh runAction = do let handleResponse exception@(ResponseException response) = case response of Wai.ResponseBuilder status headers builder -> do tables <- readIORef ?touchedTables + trackedIds <- readIORef ?trackedIds + trackedConditions <- readIORef ?trackedConditions lastPing <- getCurrentTime -- It's important that we evaluate the response to HNF here @@ -124,14 +152,14 @@ autoRefresh runAction = do lastResponse <- Exception.evaluate (ByteString.toLazyByteString builder) event <- MVar.newEmptyMVar - let session = AutoRefreshSession { id, renderView, event, tables, lastResponse, lastPing } + let session = AutoRefreshSession { id, renderView, event, tables, lastResponse, lastPing, trackedIds, trackedConditions } modifyIORef' autoRefreshServer (\s -> s { sessions = session:s.sessions } ) async (gcSessions autoRefreshServer) registerNotificationTrigger ?touchedTables autoRefreshServer throw exception - _ -> error "Unimplemented WAI response type." + _ -> error "Unimplemented WAI response type." runAction `Exception.catch` handleResponse @@ -161,14 +189,13 @@ instance WSApp AutoRefreshWSApp where let handleOtherException :: SomeException -> IO () handleOtherException ex = Log.error ("AutoRefresh: Failed to re-render view: " <> tshow ex) + let currentRequest = ?request + let dummyRespond _ = error "AutoRefresh: respond should not be called directly" + let onRender = (renderView currentRequest dummyRespond) `catch` handleResponseException + async $ forever do MVar.takeMVar event - let currentRequest = ?request - -- Create a dummy respond function that does nothing, since actual response - -- is handled by the handleResponseException handler - let dummyRespond _ = error "AutoRefresh: respond should not be called directly" - ((renderView currentRequest dummyRespond) `catch` handleResponseException) `catch` handleOtherException - pure () + onRender `catch` handleOtherException pure () @@ -188,13 +215,18 @@ instance WSApp AutoRefreshWSApp where modifyIORef' autoRefreshServer (\server -> server { sessions = filter (\AutoRefreshSession { id } -> id /= sessionId) server.sessions }) AwaitingSessionID -> pure () - +-- | Registers row-level triggers with smart ID-based filtering. +-- +-- Uses row-level PostgreSQL triggers that include the changed row data in the notification payload. +-- For UPDATE\/DELETE: extracts the row ID from the payload and checks if it's in the tracked set. +-- For INSERT: always refreshes (we can't know if the new row matches without re-querying). +-- For tables without ID tracking (raw SQL, fetchCount): always refreshes. registerNotificationTrigger :: (?modelContext :: ModelContext, ?context :: ControllerContext) => IORef (Set Text) -> IORef AutoRefreshServer -> IO () registerNotificationTrigger touchedTablesVar autoRefreshServer = do touchedTables <- Set.toList <$> readIORef touchedTablesVar subscribedTables <- (.subscribedTables) <$> (autoRefreshServer |> readIORef) - let subscriptionRequired = touchedTables |> filter (\table -> subscribedTables |> Set.notMember table) + let subscriptionRequired = touchedTables |> filter (\table -> table `Set.notMember` subscribedTables) -- In development, always re-run trigger SQL for all touched tables because -- `make db` drops and recreates the database, destroying triggers that were @@ -206,30 +238,171 @@ registerNotificationTrigger touchedTablesVar autoRefreshServer = do pgListener <- (.pgListener) <$> readIORef autoRefreshServer subscriptions <- subscriptionRequired |> mapM (\table -> do - -- We need to add the trigger from the main IHP database role other we will get this error: - -- ERROR: permission denied for schema public withRowLevelSecurityDisabled do let pool = ?modelContext.hasqlPool - runSessionHasql pool (HasqlSession.script (notificationTriggerSQL table)) + runSessionHasql pool (mapM_ HasqlSession.script (notificationTriggerStatements table)) - pgListener |> PGListener.subscribe (channelName table) \notification -> do - sessions <- (.sessions) <$> readIORef autoRefreshServer - sessions - |> filter (\session -> table `Set.member` session.tables) - |> map (\session -> session.event) - |> mapM (\event -> MVar.tryPutMVar event ()) - pure ()) + pgListener |> PGListener.subscribeJSON (channelName table) (\payload -> do + resolvedPayload <- resolveAutoRefreshPayload payload + sessions <- (.sessions) <$> readIORef autoRefreshServer + sessions |> mapM_ (handleSmartRowChange table resolvedPayload) + pure ())) -- Re-run trigger SQL for already-subscribed tables in dev mode when isDevelopment do - let alreadySubscribed = touchedTables |> filter (\table -> subscribedTables |> Set.member table) + let alreadySubscribed = touchedTables |> filter (`Set.member` subscribedTables) forM_ alreadySubscribed \table -> do withRowLevelSecurityDisabled do let pool = ?modelContext.hasqlPool - runSessionHasql pool (HasqlSession.script (notificationTriggerSQL table)) + runSessionHasql pool (mapM_ HasqlSession.script (notificationTriggerStatements table)) modifyIORef' autoRefreshServer (\s -> s { subscriptions = s.subscriptions <> subscriptions }) pure () + where + handleSmartRowChange table resolvedPayload session@AutoRefreshSession { tables, event, trackedIds, trackedConditions } + | table `Set.member` tables = do + let conditions = Map.lookup table trackedConditions + let shouldRefreshNow = case Map.lookup table trackedIds of + Nothing -> True -- table not tracked with IDs (raw SQL, fetchCount, etc.) + Just ids | Set.null ids -> True -- empty ID set means can't filter + Just ids -> case resolvedPayload of + Nothing -> True -- payload resolution failed, refresh to be safe + Just payload -> shouldRefreshForPayload ids conditions payload + when shouldRefreshNow $ + MVar.tryPutMVar event () >> pure () + | otherwise = pure () + +-- | Determines whether a notification payload should trigger a refresh based on tracked IDs +-- and WHERE conditions. +-- +-- For INSERT: evaluates the INSERT payload against tracked WHERE conditions. If ANY condition +-- set matches the inserted row, we refresh. If no conditions are tracked, we refresh (safe fallback). +-- For UPDATE\/DELETE: only refresh if the row's ID is in our tracked set. +-- +-- Note: For UPDATE, this means that if an UPDATE causes a row to newly match a WHERE filter +-- (e.g. a status change), the refresh will be skipped if that row wasn't already tracked. +-- This is an acceptable tradeoff: the page will catch up on the next INSERT or on the next +-- full page load. +shouldRefreshForPayload :: Set Text -> Maybe [Maybe Dynamic] -> AutoRefreshRowChangePayload -> Bool +shouldRefreshForPayload trackedIds maybeConditions payload = + case payload.payloadOperation of + AutoRefreshInsert -> case maybeConditions of + Nothing -> True -- no condition tracking for this table + Just conditions -> any (matchesInsertPayloadDynamic newRow) conditions + where + newRow = case payload.payloadNewRow of + Just (Aeson.Object obj) -> obj + _ -> AesonKeyMap.empty + _ -> case extractRowId payload of + Nothing -> True -- can't extract ID, refresh to be safe + Just rowId -> rowId `Set.member` trackedIds + +-- | Extracts the row ID from a notification payload. +-- +-- Looks for an "id" field in either the new or old row JSON. +extractRowId :: AutoRefreshRowChangePayload -> Maybe Text +extractRowId payload = + let row = payload.payloadNewRow <|> payload.payloadOldRow + in row >>= \case + Aeson.Object obj -> case AesonKeyMap.lookup "id" obj of + Just (Aeson.String s) -> Just s + Just (Aeson.Number n) -> Just (tshow (round n :: Integer)) + _ -> Nothing + _ -> Nothing + +-- | Unwraps a 'Dynamic'-wrapped 'Maybe Condition' and evaluates it against an INSERT payload. +-- +-- Returns 'True' (refresh) when: +-- * The 'Dynamic' cannot be cast to 'Condition' (unexpected type, safe fallback) +-- * The condition is 'Nothing' (unfiltered query) +-- * The condition matches the payload +matchesInsertPayloadDynamic :: AesonKeyMap.KeyMap Aeson.Value -> Maybe Dynamic -> Bool +matchesInsertPayloadDynamic _ Nothing = True -- no condition (unfiltered query) +matchesInsertPayloadDynamic newRow (Just condDynamic) = + case fromDynamic condDynamic of + Nothing -> True -- can't cast, safe fallback + Just condition -> matchesInsertPayload condition newRow + +-- | Evaluates a 'Condition' tree against a JSON object (the inserted row). +-- +-- For each 'ColumnCondition': +-- * Strips the table prefix from the column name (e.g. @\"tasks.project_id\"@ → @\"project_id\"@) +-- * Extracts text values from the 'ConditionValue' via 'getParamPrinterText' +-- * Conditions with 'applyLeft' or 'applyRight' (e.g. @LOWER(col)@) cannot be evaluated → 'True' +-- * Unsupported operators (LIKE, regex, range comparisons, etc.) → 'True' (safe fallback) +-- +-- For compound conditions: +-- * 'AndCondition': both sub-conditions must match +-- * 'OrCondition': at least one sub-condition must match +matchesInsertPayload :: Condition -> AesonKeyMap.KeyMap Aeson.Value -> Bool +matchesInsertPayload (AndCondition a b) row = matchesInsertPayload a row && matchesInsertPayload b row +matchesInsertPayload (OrCondition a b) row = matchesInsertPayload a row || matchesInsertPayload b row +matchesInsertPayload (ColumnCondition col op value applyLeft applyRight) row + -- Can't evaluate conditions with SQL transforms (LOWER, etc.) + | isJust applyLeft || isJust applyRight = True + | otherwise = case op of + EqOp -> matchEq col value row + IsOp -> matchIs col value row + -- For all other operators (InOp, NotEq, Like, regex, range, etc.), + -- we can't reliably evaluate → safe fallback to refresh. + -- Note: InOp encodes the entire list as a single PostgreSQL array + -- parameter, so getParamPrinterText returns one string like + -- @{\"abc\",\"def\"}@ rather than individual values. We cannot + -- reliably decompose this, so we fall back to always refreshing. + _ -> True + +-- | Match a column = value condition against a JSON row. +matchEq :: Text -> ConditionValue -> AesonKeyMap.KeyMap Aeson.Value -> Bool +matchEq col (Param params) row = + case getParamPrinterText params of + [filterText] -> jsonValueMatchesText (lookupColumn col row) filterText + _ -> True -- multi-value or empty, can't evaluate +matchEq col (Literal text) row = + jsonValueMatchesText (lookupColumn col row) text + +-- | Match a column IS value condition (typically IS NULL). +matchIs :: Text -> ConditionValue -> AesonKeyMap.KeyMap Aeson.Value -> Bool +matchIs col (Literal text) row + | Text.toLower text == "null" = + case lookupColumn col row of + Nothing -> True -- column not in payload, safe fallback + Just Aeson.Null -> True -- matches IS NULL + Just _ -> False -- non-null, doesn't match IS NULL + | otherwise = True -- IS NOT NULL or other, safe fallback +matchIs _ (Param _) _ = True -- unexpected, safe fallback + +-- | Look up a column in the JSON row, stripping any table prefix. +-- +-- E.g. @\"tasks.project_id\"@ looks up key @\"project_id\"@. +lookupColumn :: Text -> AesonKeyMap.KeyMap Aeson.Value -> Maybe Aeson.Value +lookupColumn col row = + let colName = case Text.breakOnEnd "." col of + ("", c) -> c + (_, c) -> c + in AesonKeyMap.lookup (AesonKey.fromText colName) row + +-- | Compare a JSON value against a text representation from the encoder printer. +-- +-- The hasql printer quotes text values (e.g. @"\"abc\""@) but leaves UUIDs and +-- numbers unquoted. 'unquote' strips surrounding double-quotes so that string +-- comparisons work correctly. +jsonValueMatchesText :: Maybe Aeson.Value -> Text -> Bool +jsonValueMatchesText Nothing _ = True -- column not found, can't evaluate → refresh +jsonValueMatchesText (Just jsonVal) filterText = case jsonVal of + Aeson.String s -> s == unquote filterText + Aeson.Number n -> tshow (round n :: Integer) == filterText || tshow n == filterText + Aeson.Bool b -> (if b then "true" else "false") == Text.toLower filterText + || (if b then "t" else "f") == Text.toLower filterText + Aeson.Null -> Text.toLower filterText == "null" + _ -> True -- arrays, objects — can't compare, safe fallback + where + -- | Strip surrounding double-quotes added by the hasql printer for text values. + unquote t + | Text.length t >= 2 + , Text.head t == '"' + , Text.last t == '"' + = Text.init (Text.tail t) + | otherwise = t -- | Returns the ids of all sessions available to the client based on what sessions are found in the session cookie getAvailableSessions :: (?request :: Request) => IORef AutoRefreshServer -> IO [UUID] @@ -276,38 +449,85 @@ isSessionExpired now AutoRefreshSession { lastPing } = (now `diffUTCTime` lastPi -- | Returns the event name of the event that the pg notify trigger dispatches channelName :: Text -> ByteString -channelName tableName = "ar_did_change_" <> cs tableName +channelName tableName = "ar_did_change_row_" <> cs tableName --- | Returns a SQL script to set up database notification triggers. +-- | Returns a list of SQL statements to set up row-level database notification triggers. -- --- Wrapped in a DO $$ block with EXCEPTION handler because concurrent requests --- can race to CREATE OR REPLACE the same function, causing PostgreSQL to throw --- 'tuple concurrently updated' (SQLSTATE XX000). This is safe to ignore: the --- other connection's CREATE OR REPLACE will have succeeded. -notificationTriggerSQL :: Text -> Text -notificationTriggerSQL tableName = - "DO $$\n" - <> "BEGIN\n" - <> " CREATE OR REPLACE FUNCTION " <> functionName <> "() RETURNS TRIGGER AS $BODY$" +-- These triggers send a JSON payload with the operation type and old\/new row data via pg_notify. +-- For payloads exceeding ~8KB, the full JSON is stored in @large_pg_notifications@ and only +-- a reference ID is sent in the notification. +notificationTriggerStatements :: Text -> [Text] +notificationTriggerStatements tableName = + [ "BEGIN" + , "CREATE UNLOGGED TABLE IF NOT EXISTS public.large_pg_notifications (" + <> "id UUID DEFAULT uuid_generate_v4() PRIMARY KEY NOT NULL, " + <> "payload TEXT DEFAULT NULL, " + <> "created_at TIMESTAMP WITH TIME ZONE DEFAULT now() NOT NULL" + <> ")" + , "CREATE INDEX IF NOT EXISTS large_pg_notifications_created_at_index ON public.large_pg_notifications (created_at)" + , "CREATE OR REPLACE FUNCTION " <> functionName <> "() RETURNS TRIGGER AS $$" + <> "DECLARE\n" + <> " payload TEXT;\n" + <> " large_pg_notification_id UUID;\n" <> "BEGIN\n" - <> " PERFORM pg_notify('" <> cs (channelName tableName) <> "', '');\n" - <> " RETURN new;\n" + <> " IF (TG_OP = 'DELETE') THEN\n" + <> " payload := jsonb_build_object('op', lower(TG_OP), 'old', to_jsonb(OLD))::text;\n" + <> " IF octet_length(payload) > 7800 THEN\n" + <> " INSERT INTO public.large_pg_notifications (payload) VALUES (payload) RETURNING id INTO large_pg_notification_id;\n" + <> " payload := jsonb_build_object('op', lower(TG_OP), 'payloadId', large_pg_notification_id::text)::text;\n" + <> " DELETE FROM public.large_pg_notifications WHERE created_at < CURRENT_TIMESTAMP - interval '30s';\n" + <> " END IF;\n" + <> " PERFORM pg_notify('" <> cs (channelName tableName) <> "', payload);\n" + <> " RETURN OLD;\n" + <> " ELSE\n" + <> " IF (TG_OP = 'UPDATE') THEN\n" + <> " payload := jsonb_build_object('op', lower(TG_OP), 'old', to_jsonb(OLD), 'new', to_jsonb(NEW))::text;\n" + <> " ELSE\n" + <> " payload := jsonb_build_object('op', lower(TG_OP), 'new', to_jsonb(NEW))::text;\n" + <> " END IF;\n" + <> " IF octet_length(payload) > 7800 THEN\n" + <> " INSERT INTO public.large_pg_notifications (payload) VALUES (payload) RETURNING id INTO large_pg_notification_id;\n" + <> " payload := jsonb_build_object('op', lower(TG_OP), 'payloadId', large_pg_notification_id::text)::text;\n" + <> " DELETE FROM public.large_pg_notifications WHERE created_at < CURRENT_TIMESTAMP - interval '30s';\n" + <> " END IF;\n" + <> " PERFORM pg_notify('" <> cs (channelName tableName) <> "', payload);\n" + <> " RETURN NEW;\n" + <> " END IF;\n" <> "END;\n" - <> "$BODY$ language plpgsql;\n" - <> " DROP TRIGGER IF EXISTS " <> insertTriggerName <> " ON " <> tableName <> ";\n" - <> " CREATE TRIGGER " <> insertTriggerName <> " AFTER INSERT ON \"" <> tableName <> "\" FOR EACH STATEMENT EXECUTE PROCEDURE " <> functionName <> "();\n" - <> " DROP TRIGGER IF EXISTS " <> updateTriggerName <> " ON " <> tableName <> ";\n" - <> " CREATE TRIGGER " <> updateTriggerName <> " AFTER UPDATE ON \"" <> tableName <> "\" FOR EACH STATEMENT EXECUTE PROCEDURE " <> functionName <> "();\n" - <> " DROP TRIGGER IF EXISTS " <> deleteTriggerName <> " ON " <> tableName <> ";\n" - <> " CREATE TRIGGER " <> deleteTriggerName <> " AFTER DELETE ON \"" <> tableName <> "\" FOR EACH STATEMENT EXECUTE PROCEDURE " <> functionName <> "();\n" - <> "EXCEPTION\n" - <> " WHEN SQLSTATE 'XX000' THEN null; -- 'tuple concurrently updated': another connection installed it first\n" - <> "END; $$" + <> "$$ language plpgsql" + , "DROP TRIGGER IF EXISTS " <> insertTriggerName <> " ON " <> tableName + , "CREATE TRIGGER " <> insertTriggerName <> " AFTER INSERT ON \"" <> tableName <> "\" FOR EACH ROW EXECUTE PROCEDURE " <> functionName <> "()" + , "DROP TRIGGER IF EXISTS " <> updateTriggerName <> " ON " <> tableName + , "CREATE TRIGGER " <> updateTriggerName <> " AFTER UPDATE ON \"" <> tableName <> "\" FOR EACH ROW EXECUTE PROCEDURE " <> functionName <> "()" + , "DROP TRIGGER IF EXISTS " <> deleteTriggerName <> " ON " <> tableName + , "CREATE TRIGGER " <> deleteTriggerName <> " AFTER DELETE ON \"" <> tableName <> "\" FOR EACH ROW EXECUTE PROCEDURE " <> functionName <> "()" + , "COMMIT" + ] where - functionName = "ar_notify_did_change_" <> tableName - insertTriggerName = "ar_did_insert_" <> tableName - updateTriggerName = "ar_did_update_" <> tableName - deleteTriggerName = "ar_did_delete_" <> tableName + functionName = "ar_notify_row_change_" <> tableName + insertTriggerName = "ar_did_insert_row_" <> tableName + updateTriggerName = "ar_did_update_row_" <> tableName + deleteTriggerName = "ar_did_delete_row_" <> tableName + +-- | Internal: When the PostgreSQL trigger had to store the full JSON payload in @large_pg_notifications@ +-- (because @pg_notify@ payloads are limited to ~8KB), this loads the full row json so smart filtering +-- receives the full row data for ID extraction. +-- +-- Returns 'Nothing' when the payload cannot be loaded or decoded. In that case auto refresh will force +-- a refresh. +resolveAutoRefreshPayload :: (?modelContext :: ModelContext) => AutoRefreshRowChangePayload -> IO (Maybe AutoRefreshRowChangePayload) +resolveAutoRefreshPayload payload = case payload.payloadLargePayloadId of + Nothing -> pure (Just payload) + Just payloadId -> fetchAutoRefreshPayload payloadId + +fetchAutoRefreshPayload :: (?modelContext :: ModelContext) => UUID.UUID -> IO (Maybe AutoRefreshRowChangePayload) +fetchAutoRefreshPayload payloadId = do + payloadResult <- Exception.try (sqlQueryScalar "SELECT payload FROM public.large_pg_notifications WHERE id = ? LIMIT 1" (PG.Only payloadId) :: IO ByteString) + case payloadResult of + Left (_ :: Exception.SomeException) -> pure Nothing + Right payload -> case Aeson.eitherDecodeStrict' payload of + Left _ -> pure Nothing + Right result -> pure (Just result) autoRefreshStateVaultKey :: Vault.Key AutoRefreshState autoRefreshStateVaultKey = unsafePerformIO Vault.newKey diff --git a/ihp/IHP/AutoRefresh/Types.hs b/ihp/IHP/AutoRefresh/Types.hs index 4597a648e..85a99477b 100644 --- a/ihp/IHP/AutoRefresh/Types.hs +++ b/ihp/IHP/AutoRefresh/Types.hs @@ -5,13 +5,64 @@ Copyright: (c) digitally induced GmbH, 2020 -} module IHP.AutoRefresh.Types where -import IHP.Prelude -import Wai.Request.Params.Middleware (Respond) import Control.Concurrent.MVar (MVar) +import qualified Data.Aeson as Aeson +import qualified Data.Aeson.Types as AesonTypes +import qualified Data.UUID as UUID import qualified IHP.PGListener as PGListener +import IHP.Prelude +import qualified Data.Map.Strict as Map +import Data.Dynamic (Dynamic) import Network.Wai (Request) +import Wai.Request.Params.Middleware (Respond) + +-- | A database operation that can trigger an auto refresh re-render. +data AutoRefreshOperation + = AutoRefreshInsert + | AutoRefreshUpdate + | AutoRefreshDelete + deriving (Eq, Show) + +instance Aeson.FromJSON AutoRefreshOperation where + parseJSON = Aeson.withText "AutoRefreshOperation" \operation -> + case toLower operation of + "insert" -> pure AutoRefreshInsert + "update" -> pure AutoRefreshUpdate + "delete" -> pure AutoRefreshDelete + _ -> fail ("Unknown operation: " <> cs operation) + +-- | Internal: raw payload sent by the PostgreSQL trigger. +-- +-- For oversized payloads the trigger stores the full JSON in @large_pg_notifications@ and sends only a @payloadId@. +-- The auto refresh server resolves these @payloadId@s via a database lookup before building the change notification, +-- so smart filtering receives the full row json in @old@/@new@ (if payload resolution fails, auto refresh falls back +-- to forcing a refresh). +data AutoRefreshRowChangePayload = AutoRefreshRowChangePayload + { payloadOperation :: !AutoRefreshOperation + , payloadOldRow :: !(Maybe Aeson.Value) + , payloadNewRow :: !(Maybe Aeson.Value) + , payloadLargePayloadId :: !(Maybe UUID.UUID) + } deriving (Eq, Show) + +instance Aeson.FromJSON AutoRefreshRowChangePayload where + parseJSON = Aeson.withObject "AutoRefreshRowChangePayload" \object -> + AutoRefreshRowChangePayload + <$> object Aeson..: "op" + <*> object Aeson..:? "old" + <*> object Aeson..:? "new" + <*> do + payloadId <- object Aeson..:? "payloadId" + case payloadId of + Nothing -> pure Nothing + Just value -> Just <$> parseUUID value + where + parseUUID :: Text -> AesonTypes.Parser UUID.UUID + parseUUID value = case UUID.fromText value of + Just uuid -> pure uuid + Nothing -> fail "Invalid UUID for payloadId" data AutoRefreshState = AutoRefreshEnabled { sessionId :: !UUID } + data AutoRefreshSession = AutoRefreshSession { id :: !UUID -- | A callback to rerun an action within the given request and respond @@ -24,6 +75,14 @@ data AutoRefreshSession = AutoRefreshSession , lastResponse :: !LByteString -- | Keep track of the last ping to this session to close it after too much time has passed without anything happening , lastPing :: !UTCTime + -- | Tracked row IDs per table. 'Nothing' for a table means we can't filter (raw SQL / fetchCount). + -- 'Just ids' means only these IDs are relevant. Used by smart auto refresh to skip unrelated notifications. + , trackedIds :: !(Map.Map Text (Set Text)) + -- | Tracked WHERE conditions per table. Each fetch appends its condition + -- (wrapped in 'Dynamic' to avoid a circular module dependency on 'Condition'). + -- 'Nothing' means no condition (unfiltered query) — always refresh on INSERT. + -- Used by smart auto refresh to evaluate INSERT payloads against query filters. + , trackedConditions :: !(Map.Map Text [Maybe Dynamic]) } data AutoRefreshServer = AutoRefreshServer diff --git a/ihp/IHP/ControllerPrelude.hs b/ihp/IHP/ControllerPrelude.hs index 826a65610..1b171a0dc 100644 --- a/ihp/IHP/ControllerPrelude.hs +++ b/ihp/IHP/ControllerPrelude.hs @@ -59,6 +59,7 @@ import IHP.FetchPipelined import IHP.FetchRelated import Data.Aeson hiding (Success) import Network.Wai.Parse (FileInfo(..)) +import qualified Network.Wai import IHP.RouterSupport hiding (get, post) import IHP.Controller.Redirect import Database.PostgreSQL.Simple.Types (Only (..)) @@ -91,5 +92,5 @@ import IHP.HSX.ToHtml () -- -- > setModal MyModalView { .. } -- -setModal :: (?context :: ControllerContext, ?request :: Request, View view) => view -> IO () +setModal :: (?context :: ControllerContext, ?request :: Network.Wai.Request, View view) => view -> IO () setModal view = let ?view = view in Modal.setModal (ViewSupport.html view) diff --git a/ihp/IHP/Fetch.hs b/ihp/IHP/Fetch.hs index 1584d6985..14db4f4d2 100644 --- a/ihp/IHP/Fetch.hs +++ b/ihp/IHP/Fetch.hs @@ -29,12 +29,14 @@ where import IHP.Prelude import IHP.ModelSupport import IHP.QueryBuilder +import IHP.QueryBuilder.Types (SQLQuery(..)) import IHP.Hasql.FromRow (FromRowHasql(..), HasqlDecodeColumn(..)) import IHP.QueryBuilder.HasqlCompiler (buildStatement) import qualified Hasql.Decoders as Decoders import Hasql.Implicits.Encoders (DefaultParamEncoder) import qualified Hasql.Statement as Hasql -import IHP.Fetch.Statement (fetchByIdOneOrNothingStatement, fetchByIdListStatement, buildQueryListStatement, buildQueryMaybeStatement, buildCountStatement, buildExistsStatement) +import IHP.Fetch.Statement (fetchByIdOneOrNothingStatement, fetchByIdListStatement, buildCountStatement, buildExistsStatement) +import Data.Dynamic (toDyn) class Fetchable fetchable model | fetchable -> model where type FetchResult fetchable model @@ -43,7 +45,7 @@ class Fetchable fetchable model | fetchable -> model where fetchOne :: (Table model, FromRowHasql model, ?modelContext :: ModelContext) => fetchable -> IO model -- The instance declaration had to be split up because a type variable ranging over HasQueryBuilder instances is not allowed in the declaration of the associated type. The common*-functions reduce the redundancy to the necessary minimum. -instance (model ~ GetModelByTableName table, KnownSymbol table) => Fetchable (QueryBuilder table) model where +instance (model ~ GetModelByTableName table, KnownSymbol table, HasField "id" model (Id' table), Show (PrimaryKey table)) => Fetchable (QueryBuilder table) model where type instance FetchResult (QueryBuilder table) model = [model] {-# INLINE fetch #-} fetch :: (Table model, FromRowHasql model, ?modelContext :: ModelContext) => QueryBuilder table -> IO [model] @@ -57,7 +59,7 @@ instance (model ~ GetModelByTableName table, KnownSymbol table) => Fetchable (Qu fetchOne :: (?modelContext :: ModelContext) => (Table model, FromRowHasql model) => QueryBuilder table -> IO model fetchOne = commonFetchOne -instance (model ~ GetModelByTableName table, KnownSymbol table) => Fetchable (JoinQueryBuilderWrapper r table) model where +instance (model ~ GetModelByTableName table, KnownSymbol table, HasField "id" model (Id' table), Show (PrimaryKey table)) => Fetchable (JoinQueryBuilderWrapper r table) model where type instance FetchResult (JoinQueryBuilderWrapper r table) model = [model] {-# INLINE fetch #-} fetch :: (Table model, FromRowHasql model, ?modelContext :: ModelContext) => JoinQueryBuilderWrapper r table -> IO [model] @@ -71,7 +73,7 @@ instance (model ~ GetModelByTableName table, KnownSymbol table) => Fetchable (Jo fetchOne :: (?modelContext :: ModelContext) => (Table model, FromRowHasql model) => JoinQueryBuilderWrapper r table -> IO model fetchOne = commonFetchOne -instance (model ~ GetModelByTableName table, KnownSymbol table) => Fetchable (NoJoinQueryBuilderWrapper table) model where +instance (model ~ GetModelByTableName table, KnownSymbol table, HasField "id" model (Id' table), Show (PrimaryKey table)) => Fetchable (NoJoinQueryBuilderWrapper table) model where type instance FetchResult (NoJoinQueryBuilderWrapper table) model = [model] {-# INLINE fetch #-} fetch :: (Table model, FromRowHasql model, ?modelContext :: ModelContext) => NoJoinQueryBuilderWrapper table -> IO [model] @@ -85,15 +87,18 @@ instance (model ~ GetModelByTableName table, KnownSymbol table) => Fetchable (No fetchOne :: (?modelContext :: ModelContext) => (Table model, FromRowHasql model) => NoJoinQueryBuilderWrapper table -> IO model fetchOne = commonFetchOne -instance (model ~ GetModelByTableName table, KnownSymbol table, HasqlDecodeColumn value, KnownSymbol foreignTable, foreignModel ~ GetModelByTableName foreignTable, KnownSymbol columnName, HasField columnName foreignModel value, HasQueryBuilder (LabeledQueryBuilderWrapper foreignTable columnName value) NoJoins) => Fetchable (LabeledQueryBuilderWrapper foreignTable columnName value table) model where +instance (model ~ GetModelByTableName table, KnownSymbol table, HasqlDecodeColumn value, KnownSymbol foreignTable, foreignModel ~ GetModelByTableName foreignTable, KnownSymbol columnName, HasField columnName foreignModel value, HasQueryBuilder (LabeledQueryBuilderWrapper foreignTable columnName value) NoJoins, HasField "id" model (Id' table), Show (PrimaryKey table)) => Fetchable (LabeledQueryBuilderWrapper foreignTable columnName value table) model where type instance FetchResult (LabeledQueryBuilderWrapper foreignTable columnName value table) model = [LabeledData value model] {-# INLINE fetch #-} fetch :: (Table model, FromRowHasql model, ?modelContext :: ModelContext) => LabeledQueryBuilderWrapper foreignTable columnName value table -> IO [LabeledData value model] fetch !queryBuilderProvider = do - trackTableRead (tableName @model) let pool = ?modelContext.hasqlPool - let statement = buildStatement (buildQuery queryBuilderProvider) (Decoders.rowList (hasqlRowDecoder @(LabeledData value model))) - sqlStatementHasql pool () statement + let !sqlQuery' = buildQuery queryBuilderProvider + let statement = buildStatement sqlQuery' (Decoders.rowList (hasqlRowDecoder @(LabeledData value model))) + results <- sqlStatementHasql pool () statement + trackTableReadWithIds (tableName @model) (map (\m -> tshow (get #id m.contentValue)) results) + trackTableCondition (tableName @model) (toDyn <$> sqlQuery'.whereCondition) + pure results {-# INLINE fetchOneOrNothing #-} fetchOneOrNothing :: (?modelContext :: ModelContext) => (Table model, FromRowHasql model) => LabeledQueryBuilderWrapper foreignTable columnName value table -> IO (Maybe model) @@ -106,21 +111,29 @@ instance (model ~ GetModelByTableName table, KnownSymbol table, HasqlDecodeColum {-# INLINE commonFetch #-} -commonFetch :: forall model table queryBuilderProvider joinRegister. (Table model, HasQueryBuilder queryBuilderProvider joinRegister, model ~ GetModelByTableName table, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext) => queryBuilderProvider table -> IO [model] +commonFetch :: forall model table queryBuilderProvider joinRegister. (Table model, HasQueryBuilder queryBuilderProvider joinRegister, model ~ GetModelByTableName table, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, HasField "id" model (Id' table), Show (PrimaryKey table)) => queryBuilderProvider table -> IO [model] commonFetch !queryBuilder = do - trackTableRead (tableName @model) + let !sqlQuery' = buildQuery queryBuilder let pool = ?modelContext.hasqlPool - sqlStatementHasql pool () (buildQueryListStatement queryBuilder) + results <- sqlStatementHasql pool () (buildStatement sqlQuery' (Decoders.rowList (hasqlRowDecoder @model))) + trackTableReadWithIds (tableName @model) (map (\m -> tshow (get #id m)) results) + trackTableCondition (tableName @model) (toDyn <$> sqlQuery'.whereCondition) + pure results {-# INLINE commonFetchOneOrNothing #-} -commonFetchOneOrNothing :: forall model table queryBuilderProvider joinRegister. (?modelContext :: ModelContext) => (Table model, KnownSymbol table, HasQueryBuilder queryBuilderProvider joinRegister, FromRowHasql model, model ~ GetModelByTableName table) => queryBuilderProvider table -> IO (Maybe model) +commonFetchOneOrNothing :: forall model table queryBuilderProvider joinRegister. (?modelContext :: ModelContext) => (Table model, KnownSymbol table, HasQueryBuilder queryBuilderProvider joinRegister, FromRowHasql model, HasField "id" model (Id' table), Show (PrimaryKey table)) => queryBuilderProvider table -> IO (Maybe model) commonFetchOneOrNothing !queryBuilder = do - trackTableRead (tableName @model) + let !sqlQuery' = (buildQuery queryBuilder) { limitClause = Just 1 } let pool = ?modelContext.hasqlPool - sqlStatementHasql pool () (buildQueryMaybeStatement queryBuilder) + result <- sqlStatementHasql pool () (buildStatement sqlQuery' (Decoders.rowMaybe (hasqlRowDecoder @model))) + case result of + Just m -> trackTableReadWithIds (tableName @model) [tshow (get #id m)] + Nothing -> trackTableReadWithIds (tableName @model) [] + trackTableCondition (tableName @model) (toDyn <$> sqlQuery'.whereCondition) + pure result {-# INLINE commonFetchOne #-} -commonFetchOne :: forall model table queryBuilderProvider joinRegister. (?modelContext :: ModelContext) => (Table model, KnownSymbol table, Fetchable (queryBuilderProvider table) model, HasQueryBuilder queryBuilderProvider joinRegister, FromRowHasql model) => queryBuilderProvider table -> IO model +commonFetchOne :: forall model table queryBuilderProvider joinRegister. (?modelContext :: ModelContext) => (Table model, KnownSymbol table, Fetchable (queryBuilderProvider table) model, HasQueryBuilder queryBuilderProvider joinRegister, FromRowHasql model, HasField "id" model (Id' table), Show (PrimaryKey table)) => queryBuilderProvider table -> IO model commonFetchOne !queryBuilder = do maybeModel <- fetchOneOrNothing queryBuilder case maybeModel of @@ -166,36 +179,41 @@ fetchExists !queryBuilder = do {-# INLINE fetchExists #-} {-# INLINE genericFetchId #-} -genericFetchId :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder (Id' table)) => Id' table -> IO [model] +genericFetchId :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder (Id' table), HasField "id" model (Id' table), Show (PrimaryKey table)) => Id' table -> IO [model] genericFetchId !id = do - trackTableRead (tableName @model) - sqlStatementHasql ?modelContext.hasqlPool id fetchByIdListStatement + results <- sqlStatementHasql ?modelContext.hasqlPool id fetchByIdListStatement + trackTableReadWithIds (tableName @model) (map (\m -> tshow (get #id m)) results) + pure results {-# INLINE genericfetchIdOneOrNothing #-} -genericfetchIdOneOrNothing :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder (Id' table)) => Id' table -> IO (Maybe model) +genericfetchIdOneOrNothing :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder (Id' table), HasField "id" model (Id' table), Show (PrimaryKey table)) => Id' table -> IO (Maybe model) genericfetchIdOneOrNothing !id = do - trackTableRead (tableName @model) - sqlStatementHasql ?modelContext.hasqlPool id fetchByIdOneOrNothingStatement + result <- sqlStatementHasql ?modelContext.hasqlPool id fetchByIdOneOrNothingStatement + case result of + Just m -> trackTableReadWithIds (tableName @model) [tshow (get #id m)] + Nothing -> trackTableReadWithIds (tableName @model) [] + pure result {-# INLINE genericFetchIdOne #-} -genericFetchIdOne :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder (Id' table)) => Id' table -> IO model +genericFetchIdOne :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder (Id' table), HasField "id" model (Id' table), Show (PrimaryKey table)) => Id' table -> IO model genericFetchIdOne !id = do - trackTableRead (tableName @model) result <- sqlStatementHasql ?modelContext.hasqlPool id fetchByIdOneOrNothingStatement case result of - Just model -> pure model + Just model -> do + trackTableReadWithIds (tableName @model) [tshow (get #id model)] + pure model Nothing -> throwIO RecordNotFoundException { queryAndParams = cs (Hasql.toSql (fetchByIdOneOrNothingStatement @table @model)) } {-# INLINE genericFetchIds #-} -genericFetchIds :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey (GetTableName model)]) => [Id model] -> IO [model] +genericFetchIds :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey (GetTableName model)], HasField "id" model (Id' table), Show (PrimaryKey table)) => [Id model] -> IO [model] genericFetchIds !ids = query @model |> filterWhereIdIn ids |> fetch {-# INLINE genericfetchIdsOneOrNothing #-} -genericfetchIdsOneOrNothing :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey (GetTableName model)]) => [Id model] -> IO (Maybe model) +genericfetchIdsOneOrNothing :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey (GetTableName model)], HasField "id" model (Id' table), Show (PrimaryKey table)) => [Id model] -> IO (Maybe model) genericfetchIdsOneOrNothing !ids = query @model |> filterWhereIdIn ids |> fetchOneOrNothing {-# INLINE genericFetchIdsOne #-} -genericFetchIdsOne :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey (GetTableName model)]) => [Id model] -> IO model +genericFetchIdsOne :: forall table model. (Table model, KnownSymbol table, FromRowHasql model, ?modelContext :: ModelContext, model ~ GetModelByTableName table, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey (GetTableName model)], HasField "id" model (Id' table), Show (PrimaryKey table)) => [Id model] -> IO model genericFetchIdsOne !ids = query @model |> filterWhereIdIn ids |> fetchOne {-# INLINE findBy #-} @@ -209,7 +227,7 @@ findMaybeBy !field !value !queryBuilder = queryBuilder |> filterWhere (field, va findManyBy !field !value !queryBuilder = queryBuilder |> filterWhere (field, value) |> fetch -- Step.findOneByWorkflowId id == queryBuilder |> findBy #templateId id -instance (model ~ GetModelById (Id' table), GetTableName model ~ table, FilterPrimaryKey table, DefaultParamEncoder (Id' table)) => Fetchable (Id' table) model where +instance (model ~ GetModelById (Id' table), GetTableName model ~ table, FilterPrimaryKey table, DefaultParamEncoder (Id' table), HasField "id" model (Id' table), Show (PrimaryKey table)) => Fetchable (Id' table) model where type FetchResult (Id' table) model = model {-# INLINE fetch #-} fetch = genericFetchIdOne @@ -218,7 +236,7 @@ instance (model ~ GetModelById (Id' table), GetTableName model ~ table, FilterPr {-# INLINE fetchOne #-} fetchOne = genericFetchIdOne -instance (model ~ GetModelById (Id' table), GetTableName model ~ table, FilterPrimaryKey table, DefaultParamEncoder (Id' table)) => Fetchable (Maybe (Id' table)) model where +instance (model ~ GetModelById (Id' table), GetTableName model ~ table, FilterPrimaryKey table, DefaultParamEncoder (Id' table), HasField "id" model (Id' table), Show (PrimaryKey table)) => Fetchable (Maybe (Id' table)) model where type FetchResult (Maybe (Id' table)) model = [model] {-# INLINE fetch #-} fetch (Just a) = genericFetchId a @@ -230,7 +248,7 @@ instance (model ~ GetModelById (Id' table), GetTableName model ~ table, FilterPr fetchOne (Just a) = genericFetchIdOne a fetchOne Nothing = error "Fetchable (Maybe Id): Failed to fetch because given id is 'Nothing', 'Just id' was expected" -instance (model ~ GetModelById (Id' table), GetModelByTableName table ~ model, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey table]) => Fetchable [Id' table] model where +instance (model ~ GetModelById (Id' table), GetModelByTableName table ~ model, GetTableName model ~ table, DefaultParamEncoder [PrimaryKey table], HasField "id" model (Id' table), Show (PrimaryKey table)) => Fetchable [Id' table] model where type FetchResult [Id' table] model = [model] {-# INLINE fetch #-} fetch = genericFetchIds @@ -264,6 +282,8 @@ fetchLatest :: forall table queryBuilderProvider joinRegister model. , Fetchable (queryBuilderProvider table) model , Table model , FromRowHasql model + , HasField "id" model (Id' table) + , Show (PrimaryKey table) ) => queryBuilderProvider table -> IO (Maybe model) fetchLatest queryBuilder = queryBuilder |> fetchLatestBy #createdAt @@ -295,6 +315,8 @@ fetchLatestBy :: forall table createdAt queryBuilderProvider joinRegister model. , Fetchable (queryBuilderProvider table) model , Table model , FromRowHasql model + , HasField "id" model (Id' table) + , Show (PrimaryKey table) ) => Proxy createdAt -> queryBuilderProvider table -> IO (Maybe model) fetchLatestBy field queryBuilder = queryBuilder diff --git a/ihp/IHP/FetchRelated.hs b/ihp/IHP/FetchRelated.hs index 5a38fbcf5..f68b2ad94 100644 --- a/ihp/IHP/FetchRelated.hs +++ b/ihp/IHP/FetchRelated.hs @@ -119,6 +119,7 @@ instance ( -- > SELECT * FROM companies WHERE id IN (?) instance ( Eq (PrimaryKey tableName) + , Show (PrimaryKey tableName) , HasField "id" relatedModel (Id' tableName) , relatedModel ~ GetModelByTableName (GetTableName relatedModel) , GetTableName relatedModel ~ tableName diff --git a/ihp/IHP/LoginSupport/Middleware.hs b/ihp/IHP/LoginSupport/Middleware.hs index ee05aaff8..cef9ec936 100644 --- a/ihp/IHP/LoginSupport/Middleware.hs +++ b/ihp/IHP/LoginSupport/Middleware.hs @@ -25,6 +25,8 @@ initAuthentication :: forall user normalizedModel. , GetTableName normalizedModel ~ GetTableName user , FilterPrimaryKey (GetTableName normalizedModel) , KnownSymbol (GetModelName user) + , HasField "id" normalizedModel (Id' (GetTableName user)) + , Show (PrimaryKey (GetTableName normalizedModel)) ) => IO () initAuthentication = do user <- getSession @(Id user) (sessionKey @user) diff --git a/ihp/IHP/ModelSupport.hs b/ihp/IHP/ModelSupport.hs index aeecda524..62a606ea6 100644 --- a/ihp/IHP/ModelSupport.hs +++ b/ihp/IHP/ModelSupport.hs @@ -42,6 +42,7 @@ import Data.Data import Data.Aeson (ToJSON (..), FromJSON (..)) import qualified Data.Aeson as Aeson import qualified Data.Set as Set +import qualified Data.Map.Strict as Map import qualified Text.Read as Read import qualified Hasql.Pool as HasqlPool import qualified Hasql.Pool.Config as HasqlPoolConfig @@ -81,6 +82,7 @@ notConnectedModelContext logger = ModelContext , transactionRunner = Nothing , logger = logger , trackTableReadCallback = Nothing + , trackTableConditionCallback = Nothing , rowLevelSecurity = Nothing } @@ -100,6 +102,7 @@ createModelContext databaseUrl logger = do hasqlPool <- HasqlPool.acquire hasqlPoolConfig let trackTableReadCallback = Nothing + let trackTableConditionCallback = Nothing let transactionRunner = Nothing let rowLevelSecurity = Nothing pure ModelContext { .. } @@ -1030,14 +1033,38 @@ instance Default Aeson.Value where -- trackTableRead :: (?modelContext :: ModelContext) => Text -> IO () trackTableRead tableName = case ?modelContext.trackTableReadCallback of - Just callback -> callback tableName + Just callback -> callback tableName [] Nothing -> pure () {-# INLINABLE trackTableRead #-} +-- | Like 'trackTableRead' but also records the IDs of fetched rows. +-- +-- This is used internally by 'IHP.Fetch.commonFetch' so that AutoRefresh can skip +-- notifications for rows not in the current view. +trackTableReadWithIds :: (?modelContext :: ModelContext) => Text -> [Text] -> IO () +trackTableReadWithIds tableName ids = case ?modelContext.trackTableReadCallback of + Just callback -> callback tableName ids + Nothing -> pure () +{-# INLINABLE trackTableReadWithIds #-} + +-- | Records the WHERE condition for a table read (as a 'Dynamic'-wrapped 'Condition'). +-- +-- Used internally by 'IHP.Fetch.commonFetch' so that AutoRefresh can evaluate INSERT +-- payloads against query filters without re-executing the query. +trackTableCondition :: (?modelContext :: ModelContext) => Text -> Maybe Dynamic -> IO () +trackTableCondition tableName condition = case ?modelContext.trackTableConditionCallback of + Just callback -> callback tableName condition + Nothing -> pure () +{-# INLINABLE trackTableCondition #-} + -- | Track all tables in SELECT queries executed within the given IO action. -- -- You can read the touched tables by this function by accessing the variable @?touchedTables@ inside your given IO action. -- +-- Also tracks fetched row IDs per table via @?trackedIds@. When a table is read with IDs, the IDs are accumulated. +-- When a table is read without IDs (e.g. raw SQL, fetchCount), the ID set for that table is removed +-- to indicate that filtering is not possible. +-- -- __Example:__ -- -- > withTableReadTracker do @@ -1047,13 +1074,23 @@ trackTableRead tableName = case ?modelContext.trackTableReadCallback of -- > tables <- readIORef ?touchedTables -- > -- tables = Set.fromList ["projects", "users"] -- > -withTableReadTracker :: (?modelContext :: ModelContext) => ((?modelContext :: ModelContext, ?touchedTables :: IORef (Set.Set Text)) => IO ()) -> IO () +withTableReadTracker :: (?modelContext :: ModelContext) => ((?modelContext :: ModelContext, ?touchedTables :: IORef (Set.Set Text), ?trackedIds :: IORef (Map.Map Text (Set.Set Text)), ?trackedConditions :: IORef (Map.Map Text [Maybe Dynamic])) => IO ()) -> IO () withTableReadTracker trackedSection = do touchedTablesVar <- newIORef Set.empty - let trackTableReadCallback = Just \tableName -> modifyIORef' touchedTablesVar (Set.insert tableName) + trackedIdsVar <- newIORef Map.empty + trackedConditionsVar <- newIORef Map.empty + let trackTableReadCallback = Just \tableName ids -> do + modifyIORef' touchedTablesVar (Set.insert tableName) + case ids of + [] -> modifyIORef' trackedIdsVar (Map.delete tableName) + _ -> modifyIORef' trackedIdsVar (Map.insertWith Set.union tableName (Set.fromList ids)) + let trackTableConditionCallback = Just \tableName condition -> + modifyIORef' trackedConditionsVar (Map.insertWith (<>) tableName [condition]) let oldModelContext = ?modelContext - let ?modelContext = oldModelContext { trackTableReadCallback } + let ?modelContext = oldModelContext { trackTableReadCallback, trackTableConditionCallback } let ?touchedTables = touchedTablesVar + let ?trackedIds = trackedIdsVar + let ?trackedConditions = trackedConditionsVar trackedSection diff --git a/ihp/IHP/ModelSupport/Types.hs b/ihp/IHP/ModelSupport/Types.hs index 69e0ffe50..50428e7cf 100644 --- a/ihp/IHP/ModelSupport/Types.hs +++ b/ihp/IHP/ModelSupport/Types.hs @@ -83,8 +83,12 @@ data ModelContext = ModelContext , transactionRunner :: Maybe TransactionRunner -- ^ When set, queries are sent through this runner instead of 'HasqlPool.use' directly -- | Logs all queries to this logger at log level info , logger :: Logger - -- | A callback that is called whenever a specific table is accessed using a SELECT query - , trackTableReadCallback :: Maybe (Text -> IO ()) + -- | A callback that is called whenever a specific table is accessed using a SELECT query. + -- The second argument is a list of fetched row IDs (empty list means IDs are unknown, e.g. raw SQL / fetchCount). + , trackTableReadCallback :: Maybe (Text -> [Text] -> IO ()) + -- | A callback that records the WHERE condition (as a 'Dynamic'-wrapped 'Condition') for a table read. + -- Used by AutoRefresh to evaluate INSERT payloads against query filters. + , trackTableConditionCallback :: Maybe (Text -> Maybe Dynamic -> IO ()) -- | Is set to a value if row level security was enabled at runtime , rowLevelSecurity :: Maybe RowLevelSecurityContext } diff --git a/ihp/IHP/QueryBuilder/Types.hs b/ihp/IHP/QueryBuilder/Types.hs index b3e85ab04..9347ae6fa 100644 --- a/ihp/IHP/QueryBuilder/Types.hs +++ b/ihp/IHP/QueryBuilder/Types.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE BangPatterns, TypeFamilies, DataKinds, PolyKinds, TypeApplications, ScopedTypeVariables, ConstraintKinds, TypeOperators, GADTs, UndecidableInstances, StandaloneDeriving, FunctionalDependencies, FlexibleContexts, InstanceSigs, AllowAmbiguousTypes, DeriveAnyClass #-} +{-# LANGUAGE BangPatterns, TypeFamilies, DataKinds, PolyKinds, TypeApplications, ScopedTypeVariables, ConstraintKinds, TypeOperators, GADTs, UndecidableInstances, StandaloneDeriving, FunctionalDependencies, FlexibleContexts, InstanceSigs, AllowAmbiguousTypes, DeriveAnyClass, MagicHash #-} {-| Module: IHP.QueryBuilder.Types Description: Core data types for the QueryBuilder @@ -33,6 +33,8 @@ module IHP.QueryBuilder.Types , DefaultScope (..) , EqOrIsOperator (..) , FilterPrimaryKey (..) + -- * Param value extraction +, getParamPrinterText ) where import IHP.Prelude @@ -42,6 +44,9 @@ import qualified Control.DeepSeq as DeepSeq import qualified GHC.Generics import qualified Hasql.Encoders as Encoders import qualified Prelude +import Unsafe.Coerce (unsafeCoerce) +import qualified Data.DList as DList +import GHC.Exts (Any, Int#) import qualified Data.Text as Text -- | Represents whether string matching should be case-sensitive or not @@ -165,6 +170,26 @@ instance Eq ConditionValue where Literal a == Literal b = a == b _ == _ = False -- Params cannot be compared for equality +-- | Mirror of hasql's internal @Params@ record (5 fields in hasql 1.10.x). +-- Only the 5th field (printer) is accessed. +-- +-- Fields use @Int#@ and @Any@ to match compiled layouts where +-- GHCi would otherwise ignore @{-# UNPACK #-}@ pragmas. +data ParamsMirror a = ParamsMirror + Int# -- size (unboxed, matching compiled hasql) + Any -- unknownTypes + Any -- columnsMetadata + Any -- serializer + (a -> DList.DList Text) -- printer + +-- | Extract the text representation of parameter values from an 'Encoders.Params'. +-- +-- Uses 'unsafeCoerce' to access the internal @printer@ field of the @Params@ +-- record to get the human-readable text form of encoded values. +getParamPrinterText :: Encoders.Params () -> [Text] +getParamPrinterText p = DList.toList (printer ()) + where ParamsMirror _ _ _ _ printer = unsafeCoerce p + -- | Represents a WHERE condition data Condition = ColumnCondition !Text !FilterOperator !ConditionValue !(Maybe Text) !(Maybe Text) diff --git a/ihp/IHP/ValidationSupport/ValidateIsUnique.hs b/ihp/IHP/ValidationSupport/ValidateIsUnique.hs index b71b8723f..6d3bbef0f 100644 --- a/ihp/IHP/ValidationSupport/ValidateIsUnique.hs +++ b/ihp/IHP/ValidationSupport/ValidateIsUnique.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} module IHP.ValidationSupport.ValidateIsUnique ( validateIsUnique , validateIsUniqueCaseInsensitive @@ -29,7 +30,7 @@ import Hasql.Implicits.Encoders (DefaultParamEncoder) -- > Right user -> do -- > createRecord user -- > redirectTo UsersAction -validateIsUnique :: forall field model savedModel fieldValue modelId savedModelId. ( +validateIsUnique :: forall field model savedModel fieldValue. ( savedModel ~ NormalizeModel model , ?modelContext :: ModelContext , FromRowHasql savedModel @@ -41,12 +42,12 @@ validateIsUnique :: forall field model savedModel fieldValue modelId savedModelI , EqOrIsOperator fieldValue , HasField "meta" model MetaBag , SetField "meta" model MetaBag - , HasField "id" savedModel savedModelId - , HasField "id" model modelId - , savedModelId ~ modelId - , Eq modelId + , HasField "id" savedModel (Id' (GetTableName savedModel)) + , HasField "id" model (Id' (GetTableName savedModel)) + , Eq (Id' (GetTableName savedModel)) , GetModelByTableName (GetTableName savedModel) ~ savedModel , Table savedModel + , Show (PrimaryKey (GetTableName savedModel)) ) => Proxy field -> model -> IO model validateIsUnique fieldProxy model = validateIsUniqueCaseAware fieldProxy model True {-# INLINE validateIsUnique #-} @@ -70,7 +71,7 @@ validateIsUnique fieldProxy model = validateIsUniqueCaseAware fieldProxy model T -- > Right user -> do -- > createRecord user -- > redirectTo UsersAction -validateIsUniqueCaseInsensitive :: forall field model savedModel fieldValue modelId savedModelId. ( +validateIsUniqueCaseInsensitive :: forall field model savedModel fieldValue. ( savedModel ~ NormalizeModel model , ?modelContext :: ModelContext , FromRowHasql savedModel @@ -82,18 +83,18 @@ validateIsUniqueCaseInsensitive :: forall field model savedModel fieldValue mode , EqOrIsOperator fieldValue , HasField "meta" model MetaBag , SetField "meta" model MetaBag - , HasField "id" savedModel savedModelId - , HasField "id" model modelId - , savedModelId ~ modelId - , Eq modelId + , HasField "id" savedModel (Id' (GetTableName savedModel)) + , HasField "id" model (Id' (GetTableName savedModel)) + , Eq (Id' (GetTableName savedModel)) , GetModelByTableName (GetTableName savedModel) ~ savedModel , Table savedModel + , Show (PrimaryKey (GetTableName savedModel)) ) => Proxy field -> model -> IO model validateIsUniqueCaseInsensitive fieldProxy model = validateIsUniqueCaseAware fieldProxy model False {-# INLINE validateIsUniqueCaseInsensitive #-} -- | Internal helper for 'validateIsUnique' and 'validateIsUniqueCaseInsensitive' -validateIsUniqueCaseAware :: forall field model savedModel fieldValue modelId savedModelId. ( +validateIsUniqueCaseAware :: forall field model savedModel fieldValue. ( savedModel ~ NormalizeModel model , ?modelContext :: ModelContext , FromRowHasql savedModel @@ -105,12 +106,12 @@ validateIsUniqueCaseAware :: forall field model savedModel fieldValue modelId sa , EqOrIsOperator fieldValue , HasField "meta" model MetaBag , SetField "meta" model MetaBag - , HasField "id" savedModel savedModelId - , HasField "id" model modelId - , savedModelId ~ modelId - , Eq modelId + , HasField "id" savedModel (Id' (GetTableName savedModel)) + , HasField "id" model (Id' (GetTableName savedModel)) + , Eq (Id' (GetTableName savedModel)) , GetModelByTableName (GetTableName savedModel) ~ savedModel , Table savedModel + , Show (PrimaryKey (GetTableName savedModel)) ) => Proxy field -> model -> Bool -> IO model validateIsUniqueCaseAware fieldProxy model caseSensitive = do let value = getField @field model @@ -139,7 +140,7 @@ validateIsUniqueCaseAware fieldProxy model caseSensitive = do -- > Right user -> do -- > createRecord user -- > redirectTo UsersAction -withCustomErrorMessageIO :: forall field model savedModel fieldValue modelId savedModelId. ( +withCustomErrorMessageIO :: forall field model savedModel fieldValue. ( savedModel ~ NormalizeModel model , ?modelContext :: ModelContext , KnownSymbol field @@ -150,10 +151,9 @@ withCustomErrorMessageIO :: forall field model savedModel fieldValue modelId sav , EqOrIsOperator fieldValue , HasField "meta" model MetaBag , SetField "meta" model MetaBag - , HasField "id" savedModel savedModelId - , HasField "id" model modelId - , savedModelId ~ modelId - , Eq modelId + , HasField "id" savedModel (Id' (GetTableName savedModel)) + , HasField "id" model (Id' (GetTableName savedModel)) + , Eq (Id' (GetTableName savedModel)) , GetModelByTableName (GetTableName savedModel) ~ savedModel ) => Text -> (Proxy field -> model -> IO model) -> Proxy field -> model -> IO model withCustomErrorMessageIO message validator field model = do diff --git a/ihp/Test/Test/AutoRefreshSpec.hs b/ihp/Test/Test/AutoRefreshSpec.hs index 150a6b9cd..5d749e247 100644 --- a/ihp/Test/Test/AutoRefreshSpec.hs +++ b/ihp/Test/Test/AutoRefreshSpec.hs @@ -11,11 +11,19 @@ import IHP.Prelude import IHP.Environment import IHP.FrameworkConfig import IHP.ControllerPrelude hiding (get, request) +import IHP.AutoRefresh.View import Network.Wai import Network.Wai.Internal (ResponseReceived(..)) import Network.HTTP.Types -import IHP.AutoRefresh (globalAutoRefreshServerVar) +import IHP.AutoRefresh (globalAutoRefreshServerVar, autoRefreshStateVaultKey, matchesInsertPayload, shouldRefreshForPayload) import IHP.AutoRefresh.Types +import qualified Hasql.Encoders as Encoders +import Data.Functor.Contravariant (contramap) +import qualified Data.Aeson as Aeson +import qualified Data.Aeson.KeyMap as AesonKeyMap +import qualified Data.Aeson.Key as AesonKey +import Data.Dynamic (toDyn) +import qualified Data.Set as Set import qualified Control.Concurrent.MVar as MVar import IHP.Controller.Response (ResponseException(..)) import qualified Control.Exception as Exception @@ -24,6 +32,9 @@ import IHP.Log.Types (Logger(..), LogLevel(..)) import IHP.Server (initMiddlewareStack) import IHP.Test.Mocking import qualified Network.Wai as Wai +import qualified Data.Vault.Lazy as Vault +import qualified Text.Blaze.Html.Renderer.Text as BlazeHtml +import qualified Data.UUID as UUID data WebApplication = WebApplication deriving (Eq, Show, Data) @@ -67,29 +78,20 @@ callActionWithQueryParams -> IO Response callActionWithQueryParams pgListener controller queryParams = do let MockContext { frameworkConfig, modelContext } = ?mocking - - -- Build request with query params (GET-style, not POST body) let baseRequest = ?request { Wai.queryString = map (\(k,v) -> (k, Just v)) queryParams , Wai.rawQueryString = renderSimpleQuery True queryParams } - - -- Capture the response responseRef <- newIORef Nothing let captureRespond response = do writeIORef responseRef (Just response) pure ResponseReceived - - -- Create the controller app let controllerApp req respond = do let ?request = req let ?respond = respond runActionWithNewContext controller - - -- Run through middleware stack with PGListener enabled middlewareStack <- initMiddlewareStack frameworkConfig modelContext (Just pgListener) _ <- middlewareStack controllerApp baseRequest captureRespond - readIORef responseRef >>= \case Just response -> pure response Nothing -> error "callActionWithQueryParams: No response was returned by the controller" @@ -103,58 +105,186 @@ testLogger = Logger , cleanup = pure () } +renderMeta :: (?context :: ControllerContext) => Text +renderMeta = cs (BlazeHtml.renderHtml autoRefreshMeta) + +withFreshContextWithRequest :: Request -> (ControllerContext -> IO a) -> IO a +withFreshContextWithRequest request block = do + let ?request = request + context <- newControllerContext + block context + +withFreshContext :: (ControllerContext -> IO a) -> IO a +withFreshContext = withFreshContextWithRequest Wai.defaultRequest + tests :: Spec -tests = beforeAll (mockContextNoDatabase WebApplication config) do - describe "AutoRefresh" do - describe "renderView" do - it "should preserve query parameters when re-rendering with a websocket request" $ withContext do - -- Clean up any leftover global state from previous tests - MVar.modifyMVar_ globalAutoRefreshServerVar (\_ -> pure Nothing) - - PGListener.withPGListener "" testLogger \pgListener -> do - -- 1. Call the action with query params — this triggers autoRefresh - -- which stores a session with renderView - response <- callActionWithQueryParams pgListener ShowItemAction [("marketId", "abc-123")] +tests = do + beforeAll (mockContextNoDatabase WebApplication config) do + describe "AutoRefresh" do + describe "renderView" do + it "should preserve query parameters when re-rendering with a websocket request" $ withContext do + -- Clean up any leftover global state from previous tests + MVar.modifyMVar_ globalAutoRefreshServerVar (\_ -> pure Nothing) + + PGListener.withPGListener "" testLogger \pgListener -> do + -- 1. Call the action with query params — this triggers autoRefresh + -- which stores a session with renderView + response <- callActionWithQueryParams pgListener ShowItemAction [("marketId", "abc-123")] + body <- responseBody response + cs body `shouldBe` ("abc-123" :: Text) + + -- 2. Extract the stored renderView from the AutoRefreshSession + maybeServerRef <- MVar.readMVar globalAutoRefreshServerVar + serverRef <- case maybeServerRef of + Just ref -> pure ref + Nothing -> error "AutoRefreshServer was not created" + + server <- readIORef serverRef + session <- case server.sessions of + (s:_) -> pure s + [] -> error "No AutoRefresh sessions found" + + -- 3. Call renderView with a bare request (simulating WebSocket re-render) + -- The WebSocket request has NO query params — this is the bug scenario + let bareRequest = defaultRequest + result <- Exception.try $ session.renderView bareRequest (\_ -> error "respond should not be called") + case result of + Left (ResponseException reResponse) -> do + reBody <- responseBody reResponse + -- If query params are NOT preserved, this would throw ParamNotFoundException + -- instead of reaching here with the correct value + cs reBody `shouldBe` ("abc-123" :: Text) + Right _ -> + expectationFailure "renderView should have thrown ResponseException" + + -- Cleanup + MVar.modifyMVar_ globalAutoRefreshServerVar (\_ -> pure Nothing) + + describe "graceful degradation without PGListener" do + it "should run the action without crashing when PGListener is not available" $ withContext do + MVar.modifyMVar_ globalAutoRefreshServerVar (\_ -> pure Nothing) + + response <- callActionWithParams ShowItemAction [("marketId", "degraded-ok")] body <- responseBody response - cs body `shouldBe` ("abc-123" :: Text) + cs body `shouldBe` ("degraded-ok" :: Text) - -- 2. Extract the stored renderView from the AutoRefreshSession + -- Verify autoRefresh skipped subscription machinery entirely maybeServerRef <- MVar.readMVar globalAutoRefreshServerVar - serverRef <- case maybeServerRef of - Just ref -> pure ref - Nothing -> error "AutoRefreshServer was not created" - - server <- readIORef serverRef - session <- case server.sessions of - (s:_) -> pure s - [] -> error "No AutoRefresh sessions found" - - -- 3. Call renderView with a bare request (simulating WebSocket re-render) - -- The WebSocket request has NO query params — this is the bug scenario - let bareRequest = defaultRequest - result <- Exception.try $ session.renderView bareRequest (\_ -> error "respond should not be called") - case result of - Left (ResponseException reResponse) -> do - reBody <- responseBody reResponse - -- If query params are NOT preserved, this would throw ParamNotFoundException - -- instead of reaching here with the correct value - cs reBody `shouldBe` ("abc-123" :: Text) - Right _ -> - expectationFailure "renderView should have thrown ResponseException" - - -- Cleanup - MVar.modifyMVar_ globalAutoRefreshServerVar (\_ -> pure Nothing) + case maybeServerRef of + Nothing -> pure () + Just _ -> expectationFailure "Expected globalAutoRefreshServerVar to be Nothing" + + describe "AutoRefresh meta tag" do + it "renders nothing when disabled" do + withFreshContext \context -> do + frozen <- freeze context + let ?context = frozen + renderMeta `shouldBe` "" + + it "includes the session id when enabled" do + let requestWithAutoRefresh = Wai.defaultRequest + { Wai.vault = Vault.insert autoRefreshStateVaultKey (AutoRefreshEnabled UUID.nil) Wai.defaultRequest.vault + } + withFreshContextWithRequest requestWithAutoRefresh \context -> do + frozen <- freeze context + let ?context = frozen + (cs renderMeta :: String) `shouldContain` "ihp-auto-refresh-id" + + describe "matchesInsertPayload" do + let mkRow pairs = AesonKeyMap.fromList [(AesonKey.fromText k, v) | (k, v) <- pairs] + let textParam val = Param (contramap (const val) (Encoders.param (Encoders.nonNullable Encoders.text))) + let uuidParam val = Param (contramap (const val) (Encoders.param (Encoders.nonNullable Encoders.uuid))) + + it "returns True for a matching EqOp condition" do + let row = mkRow [("project_id", Aeson.String "abc-123")] + let condition = ColumnCondition "tasks.project_id" EqOp (textParam "abc-123") Nothing Nothing + matchesInsertPayload condition row `shouldBe` True + + it "returns False for a non-matching EqOp condition" do + let row = mkRow [("project_id", Aeson.String "other-id")] + let condition = ColumnCondition "tasks.project_id" EqOp (textParam "abc-123") Nothing Nothing + matchesInsertPayload condition row `shouldBe` False + + it "handles UUID values" do + let uuid = "a7a37bca-417b-21d5-38fc-7f9000efe79c" :: UUID.UUID + let row = mkRow [("project_id", Aeson.String "a7a37bca-417b-21d5-38fc-7f9000efe79c")] + let condition = ColumnCondition "tasks.project_id" EqOp (uuidParam uuid) Nothing Nothing + matchesInsertPayload condition row `shouldBe` True + + it "returns False for AndCondition where one doesn't match" do + let row = mkRow [("project_id", Aeson.String "abc"), ("status", Aeson.String "active")] + let cond1 = ColumnCondition "tasks.project_id" EqOp (textParam "abc") Nothing Nothing + let cond2 = ColumnCondition "tasks.status" EqOp (textParam "inactive") Nothing Nothing + matchesInsertPayload (AndCondition cond1 cond2) row `shouldBe` False + + it "returns True for AndCondition where both match" do + let row = mkRow [("project_id", Aeson.String "abc"), ("status", Aeson.String "active")] + let cond1 = ColumnCondition "tasks.project_id" EqOp (textParam "abc") Nothing Nothing + let cond2 = ColumnCondition "tasks.status" EqOp (textParam "active") Nothing Nothing + matchesInsertPayload (AndCondition cond1 cond2) row `shouldBe` True + + it "returns True for OrCondition where one matches" do + let row = mkRow [("status", Aeson.String "active")] + let cond1 = ColumnCondition "tasks.status" EqOp (textParam "active") Nothing Nothing + let cond2 = ColumnCondition "tasks.status" EqOp (textParam "inactive") Nothing Nothing + matchesInsertPayload (OrCondition cond1 cond2) row `shouldBe` True + + it "returns True for unsupported operator (safe fallback)" do + let row = mkRow [("name", Aeson.String "hello")] + let condition = ColumnCondition "tasks.name" (LikeOp CaseSensitive) (textParam "%hello%") Nothing Nothing + matchesInsertPayload condition row `shouldBe` True + + it "returns True when condition has applyLeft (e.g. LOWER)" do + let row = mkRow [("name", Aeson.String "Hello")] + let condition = ColumnCondition "tasks.name" EqOp (textParam "hello") (Just "LOWER") Nothing + matchesInsertPayload condition row `shouldBe` True + + it "handles IS NULL condition" do + let row = mkRow [("deleted_at", Aeson.Null)] + let condition = ColumnCondition "tasks.deleted_at" IsOp (Literal "NULL") Nothing Nothing + matchesInsertPayload condition row `shouldBe` True + + it "rejects IS NULL when value is not null" do + let row = mkRow [("deleted_at", Aeson.String "2024-01-01")] + let condition = ColumnCondition "tasks.deleted_at" IsOp (Literal "NULL") Nothing Nothing + matchesInsertPayload condition row `shouldBe` False + + it "returns True for InOp (safe fallback, cannot decompose array params)" do + let row = mkRow [("status", Aeson.String "active")] + let condition = ColumnCondition "tasks.status" InOp (textParam "active") Nothing Nothing + matchesInsertPayload condition row `shouldBe` True + + describe "shouldRefreshForPayload" do + let mkInsertPayload row = AutoRefreshRowChangePayload AutoRefreshInsert Nothing (Just (Aeson.Object row)) Nothing + let mkUpdatePayload rowId = AutoRefreshRowChangePayload AutoRefreshUpdate Nothing (Just (Aeson.Object (AesonKeyMap.fromList [(AesonKey.fromText "id", Aeson.String rowId)]))) Nothing + let mkRow pairs = AesonKeyMap.fromList [(AesonKey.fromText k, v) | (k, v) <- pairs] + let textParam val = Param (contramap (const val) (Encoders.param (Encoders.nonNullable Encoders.text))) + + it "INSERT with no conditions tracked → refreshes" do + let payload = mkInsertPayload (mkRow [("id", Aeson.String "1")]) + shouldRefreshForPayload (Set.fromList ["1"]) Nothing payload `shouldBe` True + + it "INSERT with matching condition → refreshes" do + let row = mkRow [("project_id", Aeson.String "abc")] + let condition = ColumnCondition "tasks.project_id" EqOp (textParam "abc") Nothing Nothing + let payload = mkInsertPayload row + shouldRefreshForPayload (Set.fromList []) (Just [Just (toDyn condition)]) payload `shouldBe` True + + it "INSERT with non-matching condition → does NOT refresh" do + let row = mkRow [("project_id", Aeson.String "other")] + let condition = ColumnCondition "tasks.project_id" EqOp (textParam "abc") Nothing Nothing + let payload = mkInsertPayload row + shouldRefreshForPayload (Set.fromList []) (Just [Just (toDyn condition)]) payload `shouldBe` False - describe "graceful degradation without PGListener" do - it "should run the action without crashing when PGListener is not available" $ withContext do - MVar.modifyMVar_ globalAutoRefreshServerVar (\_ -> pure Nothing) + it "INSERT with unfiltered query (Nothing condition) → refreshes" do + let row = mkRow [("project_id", Aeson.String "any")] + let payload = mkInsertPayload row + shouldRefreshForPayload (Set.fromList []) (Just [Nothing]) payload `shouldBe` True - response <- callActionWithParams ShowItemAction [("marketId", "degraded-ok")] - body <- responseBody response - cs body `shouldBe` ("degraded-ok" :: Text) + it "UPDATE with tracked ID → refreshes" do + let payload = mkUpdatePayload "abc-123" + shouldRefreshForPayload (Set.fromList ["abc-123"]) Nothing payload `shouldBe` True - -- Verify autoRefresh skipped subscription machinery entirely - maybeServerRef <- MVar.readMVar globalAutoRefreshServerVar - case maybeServerRef of - Nothing -> pure () - Just _ -> expectationFailure "Expected globalAutoRefreshServerVar to be Nothing" + it "UPDATE with untracked ID → does NOT refresh" do + let payload = mkUpdatePayload "other-id" + shouldRefreshForPayload (Set.fromList ["abc-123"]) Nothing payload `shouldBe` False diff --git a/ihp/ihp.cabal b/ihp/ihp.cabal index 8d6be37ad..c5a3b2df3 100644 --- a/ihp/ihp.cabal +++ b/ihp/ihp.cabal @@ -78,6 +78,7 @@ common shared-properties , interpolate , split , containers + , dlist , http-media , cookie , process