Skip to content

Commit

Permalink
Update ToolsAugmentation and OpenAIModel convertMessages
Browse files Browse the repository at this point in the history
  • Loading branch information
Corina Gum committed Aug 16, 2024
1 parent 5a8032e commit c562870
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
8 changes: 5 additions & 3 deletions js/packages/teams-ai/src/augmentations/ToolsAugmentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,14 @@ export class ToolsAugmentation implements Augmentation<string> {
parameters = {};
}

commands.push({
const doCommand: PredictedDoCommand = {
type: 'DO',
action: toolCall.function.name,
id: toolCall.id,
actionId: toolCall.id,
parameters: parameters
} as PredictedDoCommand);
};

commands.push(doCommand);
}
return Promise.resolve({ type: 'plan', commands });
}
Expand Down
51 changes: 29 additions & 22 deletions js/packages/teams-ai/src/models/OpenAIModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,11 @@ export class OpenAIModel implements PromptCompletionModel {
this._events.emit('responseReceived', context, memory, response);
return response;
} catch (err: unknown) {
console.log(err);
return this.returnError(err, input);
}
}

/**
* Converts the messages to ChatCompletionMessageParam[].
* @param {Message<string>} messages - The messages from result.output.
Expand All @@ -419,23 +421,25 @@ export class OpenAIModel implements PromptCompletionModel {
private convertMessages(messages: Message<string>[]): ChatCompletionMessageParam[] {
const params: ChatCompletionMessageParam[] = [];
// Iterate through the messages and check for action calls
for (const message of messages) {
let param: ChatCompletionMessageParam = {
role: 'user',
content: message.content ?? ''
};

if (message.name) {
param.name = message.name;
}

const toolCallParams: ChatCompletionMessageToolCall[] = [];
let param: ChatCompletionMessageParam = {
role: 'user',
content: ''
};

if (message.role === 'assistant') {
for (const message of messages) {
if (message.role === 'user') {
param.content = message.content ?? '';
} else if (message.role === 'system') {
param = {
role: 'system',
content: message.content ?? ''
};
} else if (message.role === 'assistant') {
param = {
role: 'assistant',
content: message.content ?? ''
};
const toolCallParams: ChatCompletionMessageToolCall[] = [];

if (message.action_calls && message.action_calls.length > 0) {
for (const toolCall of message.action_calls) {
Expand All @@ -445,28 +449,30 @@ export class OpenAIModel implements PromptCompletionModel {
name: toolCall.function.name,
arguments: toolCall.function.arguments
},
type: 'function'
type: toolCall.type
});
}

param.tool_calls = toolCallParams;
}
} else if ((message.role = 'tool')) {
} else if (message.role === 'tool') {
param = {
role: 'tool',
tool_call_id: message.action_call_id ?? '',
content: message.content ?? ''
content: message.content ?? '',
tool_call_id: message.action_call_id ?? ''
};
} else if ((message.role = 'system')) {
} else {
param = {
role: 'system',
content: message.content ?? ''
role: 'function',
content: message.content ?? '',
name: message.name ?? ''
};
}
params.push(param);
}

return params;
}

/**
* @private
* @template TRequest
Expand Down Expand Up @@ -514,12 +520,13 @@ export class OpenAIModel implements PromptCompletionModel {
})
: [];

const parallelToolCalls = isToolsAugmentation ? template.config.completion.parallel_tool_calls : undefined;
const completion = {
...template.config.completion,
tool_choice: isToolsAugmentation ? (template.config.completion.tool_choice ?? 'auto') : undefined,
tools: chatCompletionTools,
// Only include parallel_tool_calls if tools are enabled and the template has it set; otherwise, it will default to true without being added to the API call
parallel_tool_calls: isToolsAugmentation ? template.config.completion.parallel_tool_calls : undefined
...(!!parallelToolCalls && { parallel_tool_calls: parallelToolCalls })
};

const params: ChatCompletionCreateParams = this.copyOptionsToRequest<ChatCompletionCreateParams>(
Expand Down Expand Up @@ -566,7 +573,7 @@ export class OpenAIModel implements PromptCompletionModel {

private getInputMessage(messages: Message<string>[]): Message<string> | undefined {
const last = messages.length - 1;
if (last > 0 && messages[last].role !== 'user') {
if (last > 0 && messages[last].role !== 'assistant') {
return messages[last];
}

Expand Down

0 comments on commit c562870

Please sign in to comment.