Skip to content

Commit

Permalink
Remove ToolsAugmentationConstants & update ToolsAugmentation to not u…
Browse files Browse the repository at this point in the history
…se them
  • Loading branch information
Corina Gum committed Aug 5, 2024
1 parent 52d9c47 commit 6b4e0c2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 156 deletions.
157 changes: 26 additions & 131 deletions js/packages/teams-ai/src/augmentations/ToolsAugmentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@ import { ChatCompletionAction } from '../models';
import { Plan, PredictedCommand, PredictedDoCommand, PredictedSayCommand } from '../planners';
import { PromptSection } from '../prompts';
import { Tokenizer } from '../tokenizers';
import { ActionCall, PromptResponse, ToolsAugmentationConstants } from '../types';
import { ActionCall, PromptResponse } from '../types';
import { Validation } from '../validators';

import { Augmentation } from './Augmentation';

const { SUBMIT_TOOL_OUTPUTS_VARIABLE, SUBMIT_TOOL_OUTPUTS_MAP, SUBMIT_TOOL_OUTPUTS_MESSAGES } =
ToolsAugmentationConstants;

/**
* The 'tools' augmentation is for enabling server-side action/tools calling.
* In the Teams AI Library, the equivalent to OpenAI's 'tools' functionality is called an 'action'.
Expand Down Expand Up @@ -59,104 +56,9 @@ export class ToolsAugmentation implements Augmentation<string | ActionCall[]> {
response: PromptResponse<string>,
remaining_attempts: number
): Promise<Validation> {
const validActionHandlers: ActionCall[] = [];

if (
this._actions &&
response.message &&
response.message.action_tool_calls &&
memory.getValue(SUBMIT_TOOL_OUTPUTS_VARIABLE) === true
) {
const actionCall: ActionCall[] = response.message.action_tool_calls!;
const actions = this._actions;
const toolChoice = memory.getValue('temp.toolChoice') || 'auto';

let currentCall: ActionCall | undefined;
let currentTool: ChatCompletionAction | undefined;
let functionName: string = '';

// Validate a single tool where tool_choice is a single action definition
if (toolChoice instanceof Map) {
functionName = toolChoice.get('function').get('name');
currentCall = actionCall[0];

for (const tool of actions) {
if (tool.name === functionName) {
currentTool = tool;
break;
}
}
} else {
// Validate multiple tools
for (const call of actionCall) {
functionName = call.function.name;

for (const tool of actions) {
if (tool.name === functionName) {
currentTool = tool;
currentCall = call;
break;
}
}
// Validate function name
if (!currentTool) {
continue;
}
}
}

if (!currentTool) {
return Promise.resolve({
type: 'Validation',
valid: false,
feedback: `ToolsAugmentation: The invoked action ${functionName} does not exist.`
});
}

if (currentTool && currentCall) {
// Validate required function arguments
const requiredArgs: string[] =
currentTool.parameters &&
currentTool.parameters.required &&
Array.isArray(currentTool.parameters.required)
? currentTool.parameters.required
: [];

let currentArgs = {};
try {
currentArgs = JSON.parse(currentCall.function.arguments);
} catch (error) {
return Promise.resolve({
type: 'Validation',
valid: false,
feedback: `ToolsAugmentation: Error parsing tool arguments: ${error}`
});
}

// Validate that required arguments are included in current arguments
if (
requiredArgs &&
currentArgs &&
requiredArgs.every((arg) => Object.keys(currentArgs).includes(arg))
) {
validActionHandlers.push(currentCall);
} else {
// There are no required arguments that need validation
validActionHandlers.push(currentCall);
}
}
// No tools were valid; reset ToolsAugmentation constants
if (validActionHandlers.length === 0) {
memory.setValue(SUBMIT_TOOL_OUTPUTS_VARIABLE, false);
memory.setValue(SUBMIT_TOOL_OUTPUTS_MAP, {});
memory.setValue(SUBMIT_TOOL_OUTPUTS_MESSAGES, []);
}
}

return Promise.resolve({
type: 'Validation',
valid: true,
value: validActionHandlers.length > 0 ? validActionHandlers : undefined
valid: true
});
}

Expand All @@ -175,43 +77,36 @@ export class ToolsAugmentation implements Augmentation<string | ActionCall[]> {
const toolsMap = new Map<string, string>();
const commands: PredictedCommand[] = [];

if (response.message && response.message.content) {
if (memory.getValue(SUBMIT_TOOL_OUTPUTS_VARIABLE) === true && Array.isArray(response.message.content)) {
const actionToolCalls: ActionCall[] = response.message.content;
for (const actionToolCall of actionToolCalls) {
toolsMap.set(actionToolCall.function.name, actionToolCall.id);
let parameters;

try {
parameters = JSON.parse(actionToolCall.function.arguments);
} catch (err) {
console.error('ToolsAugmentation createPlanFromResponse: Error parsing tool arguments: ', err);
parameters = {};
}
if (response.message && response.message.action_tool_calls) {
const actionToolCalls: ActionCall[] = response.message.action_tool_calls ?? [];
for (const toolCall of actionToolCalls) {
toolsMap.set(toolCall.function.name, toolCall.id);
let parameters;

commands.push({
type: 'DO',
action: actionToolCall.function.name,
parameters: parameters
} as PredictedDoCommand);
try {
parameters = JSON.parse(toolCall.function.arguments);
} catch (err) {
console.error('ToolsAugmentation createPlanFromResponse: Error parsing tool arguments: ', err);
parameters = {};
}
memory.setValue(SUBMIT_TOOL_OUTPUTS_MAP, toolsMap);
return Promise.resolve({ type: 'plan', commands });
}

return Promise.resolve({
type: 'plan',
commands: [
{
type: 'SAY',
response: response.message
} as PredictedSayCommand
]
});
commands.push({
type: 'DO',
action: toolCall.function.name,
parameters: parameters
} as PredictedDoCommand);
}
return Promise.resolve({ type: 'plan', commands });
}

return Promise.resolve({
type: 'plan',
commands: []
commands: [
{
type: 'SAY',
response: response.message
} as PredictedSayCommand
]
});
}
}
25 changes: 0 additions & 25 deletions js/packages/teams-ai/src/types/ToolsAugmentationConstants.ts

This file was deleted.

0 comments on commit 6b4e0c2

Please sign in to comment.