Skip to content
This repository was archived by the owner on Sep 29, 2025. It is now read-only.

Commit 94d3b71

Browse files
mongodbenBen Perlmutterhschawe
authored
(EAI-1117): Guardrail fix (#781)
* thx claude * Remove unused guardrail fxn --------- Co-authored-by: Ben Perlmutter <[email protected]> Co-authored-by: Helen Schawe <[email protected]>
1 parent 42469b3 commit 94d3b71

File tree

5 files changed

+74
-240
lines changed

5 files changed

+74
-240
lines changed

packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,34 @@ describe("generateResponseWithSearchTool", () => {
272272
) as UserMessage;
273273
expect(userMessage.customData).toMatchObject(searchToolMockArgs);
274274
});
275+
it("should not generate until guardrail has resolved (reject)", async () => {
276+
const generateResponse = makeGenerateResponseWithSearchTool({
277+
...makeMakeGenerateResponseWithSearchToolArgs(),
278+
inputGuardrail: async () => {
279+
// sleep for 2 seconds
280+
await new Promise((resolve) => setTimeout(resolve, 2000));
281+
return mockGuardrailRejectResult;
282+
},
283+
});
275284

285+
const result = await generateResponse(generateResponseBaseArgs);
286+
287+
expectGuardrailRejectResult(result);
288+
});
289+
it("should not generate until guardrail has resolved (pass)", async () => {
290+
const generateResponse = makeGenerateResponseWithSearchTool({
291+
...makeMakeGenerateResponseWithSearchToolArgs(),
292+
inputGuardrail: async () => {
293+
// sleep for 2 seconds
294+
await new Promise((resolve) => setTimeout(resolve, 2000));
295+
return mockGuardrailPassResult;
296+
},
297+
});
298+
299+
const result = await generateResponse(generateResponseBaseArgs);
300+
301+
expectSuccessfulResult(result);
302+
});
276303
describe("non-streaming", () => {
277304
test("should handle successful generation non-streaming", async () => {
278305
const generateResponse = makeGenerateResponseWithSearchTool(

packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import {
2626
MakeReferenceLinksFunc,
2727
makeDefaultReferenceLinks,
2828
GenerateResponse,
29-
withAbortControllerGuardrail,
3029
GenerateResponseReturnValue,
3130
InputGuardrailResult,
3231
} from "mongodb-chatbot-server";
@@ -128,14 +127,26 @@ export function makeGenerateResponseWithSearchTool({
128127

129128
const references: References = [];
130129
let userMessageCustomData: Partial<MongoDbSearchToolArgs> = {};
131-
const { result, guardrailResult } = await withAbortControllerGuardrail(
132-
async (controller) => {
133-
// Pass the tools as a separate parameter
130+
131+
// Create an AbortController for the generation
132+
const generationController = new AbortController();
133+
let guardrailRejected = false;
134+
135+
// Start guardrail check immediately and monitor it
136+
const guardrailMonitor = inputGuardrailPromise?.then((result) => {
137+
if (result?.rejected) {
138+
guardrailRejected = true;
139+
generationController.abort();
140+
}
141+
return result;
142+
});
143+
144+
// Start generation immediately (in parallel with guardrail)
145+
const generationPromise = (async () => {
146+
try {
134147
const result = streamText({
135148
...generationArgs,
136-
// Abort the stream if the guardrail AbortController is triggered
137-
abortSignal: controller.signal,
138-
// Add the search tool results to the references
149+
abortSignal: generationController.signal,
139150
onStepFinish: async ({ toolResults, toolCalls }) => {
140151
toolCalls?.forEach((toolCall) => {
141152
if (toolCall.toolName === SEARCH_TOOL_NAME) {
@@ -159,10 +170,13 @@ export function makeGenerateResponseWithSearchTool({
159170
},
160171
});
161172

173+
// Process the stream
162174
for await (const chunk of result.fullStream) {
163-
if (controller.signal.aborted) {
175+
// Check if we should abort due to guardrail rejection
176+
if (generationController.signal.aborted) {
164177
break;
165178
}
179+
166180
switch (chunk.type) {
167181
case "text-delta":
168182
if (shouldStream) {
@@ -185,22 +199,32 @@ export function makeGenerateResponseWithSearchTool({
185199
break;
186200
}
187201
}
188-
try {
189-
if (references.length > 0) {
190-
if (shouldStream) {
191-
dataStreamer?.streamData({
192-
data: references,
193-
type: "references",
194-
});
195-
}
202+
203+
// Stream references if we have any and weren't aborted
204+
if (references.length > 0 && !generationController.signal.aborted) {
205+
if (shouldStream) {
206+
dataStreamer?.streamData({
207+
data: references,
208+
type: "references",
209+
});
196210
}
197-
return result;
198-
} catch (error: unknown) {
199-
throw new Error(typeof error === "string" ? error : String(error));
200211
}
201-
},
202-
inputGuardrailPromise
203-
);
212+
213+
return result;
214+
} catch (error: unknown) {
215+
// If aborted due to guardrail, return null
216+
if (generationController.signal.aborted && guardrailRejected) {
217+
return null;
218+
}
219+
throw new Error(typeof error === "string" ? error : String(error));
220+
}
221+
})();
222+
223+
// Wait for both to complete
224+
const [guardrailResult, result] = await Promise.all([
225+
guardrailMonitor ?? Promise.resolve(undefined),
226+
generationPromise,
227+
]);
204228

205229
// If the guardrail rejected the query,
206230
// return the LLM refusal message

packages/mongodb-chatbot-server/src/processors/InputGuardrail.test.ts

Lines changed: 0 additions & 178 deletions
This file was deleted.

packages/mongodb-chatbot-server/src/processors/InputGuardrail.ts

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,40 +19,3 @@ export type InputGuardrail<
1919
> = (
2020
generateResponseParams: GenerateResponseParams
2121
) => Promise<InputGuardrailResult<Metadata>>;
22-
23-
export function withAbortControllerGuardrail<T>(
24-
fn: (abortController: AbortController) => Promise<T>,
25-
guardrailPromise?: Promise<InputGuardrailResult>
26-
): Promise<{
27-
result: T | null;
28-
guardrailResult: InputGuardrailResult | undefined;
29-
}> {
30-
const abortController = new AbortController();
31-
return (async () => {
32-
try {
33-
// Run both the main function and guardrail function in parallel
34-
const [result, guardrailResult] = await Promise.all([
35-
fn(abortController),
36-
guardrailPromise
37-
?.then((guardrailResult) => {
38-
if (guardrailResult.rejected) {
39-
abortController.abort();
40-
}
41-
return guardrailResult;
42-
})
43-
.catch((error) => {
44-
abortController.abort();
45-
return { ...guardrailFailedResult, metadata: { error } };
46-
}),
47-
]);
48-
49-
return { result, guardrailResult };
50-
} catch (error) {
51-
// If an unexpected error occurs, abort any ongoing operations
52-
if (!abortController.signal.aborted) {
53-
abortController.abort();
54-
}
55-
throw error;
56-
}
57-
})();
58-
}

packages/mongodb-rag-core/src/conversations/ConversationsService.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,5 @@ export const defaultConversationConstants: ConversationConstants = {
288288
289289
Please try to rephrase your message. Adding more details can help me respond with a relevant answer.`,
290290
LLM_NOT_WORKING: `Unfortunately, my chat functionality is not working at the moment,
291-
so I cannot respond to your message. Please try again later.
292-
293-
However, here are some links that might provide some helpful information for your message:`,
291+
so I cannot respond to your message. Please try again later.`,
294292
};

0 commit comments

Comments
 (0)