From 79a006f3b4d00eafb75b0e6f3c064717f5bbe015 Mon Sep 17 00:00:00 2001
From: Lily Du
Date: Thu, 1 Aug 2024 09:25:14 -0700
Subject: [PATCH] [repo] fix: max_tokens bug fix in completions (#1887)

## Linked issues

closes: #1835, #1800

## Details

- `max_input_tokens` was incorrectly passed in as OpenAI's `max_tokens` (for both C# and Python); a short sketch of the distinction between the two limits follows the patch below.

## Attestation Checklist

- [x] My code follows the style guidelines of this project
- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validate my changes and provide sufficient test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes
---
 .../Microsoft.TeamsAI/AI/Models/OpenAIModel.cs     | 2 +-
 python/packages/ai/teams/ai/models/openai_model.py | 8 ++++----
 python/packages/ai/teams/app.py                    | 7 ++++---
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs
index 23d8de179..3a3a779ec 100644
--- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs
+++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs
@@ -178,7 +178,7 @@ public async Task<PromptResponse> CompletePromptAsync(ITurnContext turnContext,
             IEnumerable<ChatMessage> chatMessages = prompt.Output.Select(chatMessage => chatMessage.ToOpenAIChatMessage());

             ChatCompletionOptions? chatCompletionOptions = ModelReaderWriter.Read<ChatCompletionOptions>(BinaryData.FromString($@"{{
-                ""max_tokens"": {maxInputTokens},
+                ""max_tokens"": {promptTemplate.Configuration.Completion.MaxTokens},
                 ""temperature"": {(float)promptTemplate.Configuration.Completion.Temperature},
                 ""top_p"": {(float)promptTemplate.Configuration.Completion.TopP},
                 ""presence_penalty"": {(float)promptTemplate.Configuration.Completion.PresencePenalty},
diff --git a/python/packages/ai/teams/ai/models/openai_model.py b/python/packages/ai/teams/ai/models/openai_model.py
index 8924ac024..3f6894451 100644
--- a/python/packages/ai/teams/ai/models/openai_model.py
+++ b/python/packages/ai/teams/ai/models/openai_model.py
@@ -123,7 +123,7 @@ async def complete_prompt(
         tokenizer: Tokenizer,
         template: PromptTemplate,
     ) -> PromptResponse[str]:
-        max_tokens = template.config.completion.max_input_tokens
+        max_input_tokens = template.config.completion.max_input_tokens
         model = (
             template.config.completion.model
             if template.config.completion.model is not None
@@ -134,7 +134,7 @@ async def complete_prompt(
             memory=memory,
             functions=functions,
             tokenizer=tokenizer,
-            max_tokens=max_tokens,
+            max_tokens=max_input_tokens,
         )

         if res.too_long:
@@ -142,7 +142,7 @@ async def complete_prompt(
                 status="too_long",
                 error=f"""
                 the generated chat completion prompt had a length of {res.length} tokens
-                which exceeded the max_input_tokens of {max_tokens}
+                which exceeded the max_input_tokens of {max_input_tokens}
                 """,
             )

@@ -194,7 +194,7 @@ async def complete_prompt(
             frequency_penalty=template.config.completion.frequency_penalty,
             top_p=template.config.completion.top_p,
             temperature=template.config.completion.temperature,
-            max_tokens=max_tokens,
+            max_tokens=template.config.completion.max_tokens,
             extra_body=extra_body,
         )

diff --git a/python/packages/ai/teams/app.py b/python/packages/ai/teams/app.py
index d4da12443..f943cd2b1 100644
--- a/python/packages/ai/teams/app.py
+++ b/python/packages/ai/teams/app.py
@@ -610,7 +610,7 @@ async def __handler__(context: TurnContext, state: StateT):
                 return False

             feedback = context.activity.value
-            feedback.reply_to_id=context.activity.reply_to_id
+            feedback.reply_to_id = context.activity.reply_to_id
             await func(context, state, feedback)

             await context.send_activity(
@@ -819,8 +819,9 @@ async def _run_ai_chain(self, context: TurnContext, state):
         return True

     def _contains_non_text_attachments(self, context):
-        non_text_attachments = filter(lambda a: not a.content_type.startswith(
-            "text/html"), context.activity.attachments)
+        non_text_attachments = filter(
+            lambda a: not a.content_type.startswith("text/html"), context.activity.attachments
+        )
         return len(list(non_text_attachments)) > 0

     async def _run_after_turn_middleware(self, context: TurnContext, state):
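
For readers of the diff: `max_input_tokens` bounds the rendered prompt before it is sent, while OpenAI's `max_tokens` caps the length of the generated completion, and the bug conflated the two. Below is a minimal Python sketch of that distinction, written outside the Teams AI codebase; the limit values, model name, and whitespace tokenizer are illustrative assumptions, not values from this repo:

```python
import asyncio

from openai import AsyncOpenAI

# Illustrative values only: in the library these come from the prompt
# template's completion config (max_input_tokens / max_tokens).
MAX_INPUT_TOKENS = 2048  # budget for the rendered prompt (input side)
MAX_TOKENS = 1000        # cap on the generated completion (output side)


def count_tokens(text: str) -> int:
    # Stand-in tokenizer for the sketch; a real implementation would use a
    # proper tokenizer (e.g. tiktoken) rather than whitespace splitting.
    return len(text.split())


async def main() -> None:
    prompt = "Explain the difference between input and output token limits."

    # Input-side guard: mirrors the `res.too_long` check in openai_model.py,
    # which compares the rendered prompt length against max_input_tokens.
    length = count_tokens(prompt)
    if length > MAX_INPUT_TOKENS:
        raise ValueError(
            f"prompt is {length} tokens, which exceeds "
            f"max_input_tokens of {MAX_INPUT_TOKENS}"
        )

    # Output-side cap: `max_tokens` is what the OpenAI API expects here, and
    # it is the value the patch now sources from completion.max_tokens
    # instead of max_input_tokens.
    client = AsyncOpenAI()
    response = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=MAX_TOKENS,
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```

With the pre-fix behavior, the input budget was sent as the completion cap, so a template configured with a large `max_input_tokens` and a small `max_tokens` could generate far longer (and costlier) responses than intended.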