Skip to content

Commit

Permalink
[C#] port: missing properties on ActionPlanner (#1303)
Browse files Browse the repository at this point in the history
## Linked issues

closes: #1274 

## Details
- port missing relevant public gets
  • Loading branch information
nkrama-99 authored Feb 26, 2024
1 parent f6b15df commit 1863151
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,78 @@ public async void Test_BeginTaskAsync()
Assert.Equal(planMock.Object, result);
}

[Fact]
public void Test_Get_Model()
{
// Arrange
var modelMock = new Mock<IPromptCompletionModel>();
var prompts = new PromptManager();
var promptTemplate = new PromptTemplate(
"prompt",
new(new() { })
);
var options = new ActionPlannerOptions<TurnState>(
modelMock.Object,
prompts,
(context, state, planner) => Task.FromResult(promptTemplate)
);
var planner = new ActionPlanner<TurnState>(options, new TestLoggerFactory());

// Act
var result = planner.Model;

// Assert
Assert.Equal(options.Model, result);
}

[Fact]
public void Test_Get_Prompts()
{
// Arrange
var modelMock = new Mock<IPromptCompletionModel>();
var prompts = new PromptManager();
var promptTemplate = new PromptTemplate(
"prompt",
new(new() { })
);
var options = new ActionPlannerOptions<TurnState>(
modelMock.Object,
prompts,
(context, state, planner) => Task.FromResult(promptTemplate)
);
var planner = new ActionPlanner<TurnState>(options, new TestLoggerFactory());

// Act
var result = planner.Prompts;

// Assert
Assert.Equal(options.Prompts, result);
}

[Fact]
public void Test_Get_DefaultPrompt()
{
// Arrange
var modelMock = new Mock<IPromptCompletionModel>();
var prompts = new PromptManager();
var promptTemplate = new PromptTemplate(
"prompt",
new(new() { })
);
var options = new ActionPlannerOptions<TurnState>(
modelMock.Object,
prompts,
(context, state, planner) => Task.FromResult(promptTemplate)
);
var planner = new ActionPlanner<TurnState>(options, new TestLoggerFactory());

// Act
var result = planner.DefaultPrompt;

// Assert
Assert.Equal(options.DefaultPrompt, result);
}

private sealed class TestMemory : IMemory
{
public Dictionary<string, object> Values { get; set; } = new Dictionary<string, object>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Bot.Builder;
using Microsoft.Extensions.Logging;
using Microsoft.Teams.AI.AI.Clients;
using Microsoft.Teams.AI.AI.Models;
using Microsoft.Teams.AI.AI.Prompts;
using Microsoft.Teams.AI.AI.Validators;
using Microsoft.Teams.AI.State;
Expand Down Expand Up @@ -48,6 +49,21 @@ public ActionPlanner(ActionPlannerOptions<TState> options, ILoggerFactory? logge
this._logger = loggerFactory;
}

/// <summary>
/// Gets the prompt completion model in use
/// </summary>
public IPromptCompletionModel Model { get => Options.Model; }

/// <summary>
/// Get the prompt manager in use
/// </summary>
public PromptManager Prompts { get => Options.Prompts; }

/// <summary>
/// Get the default prompt manager in use
/// </summary>
public ActionPlannerOptions<TState>.ActionPlannerPromptFactory DefaultPrompt { get => Options.DefaultPrompt; }

/// <summary>
/// Starts a new task.
/// </summary>
Expand Down Expand Up @@ -132,26 +148,26 @@ public async Task<PromptResponse> CompletePromptAsync(
CancellationToken cancellationToken = default
)
{
if (!this.Options.Prompts.HasPrompt(template.Name))
if (!this.Prompts.HasPrompt(template.Name))
{
this.Options.Prompts.AddPrompt(template.Name, template);
this.Prompts.AddPrompt(template.Name, template);
}

string historyVariable = template.Configuration.Completion.IncludeHistory ?
$"conversation.{template.Name}_history" :
$"temp.{template.Name}_history";

LLMClient<object> client = new(new(this.Options.Model, template)
LLMClient<object> client = new(new(this.Model, template)
{
HistoryVariable = historyVariable,
Validator = validator ?? new DefaultResponseValidator(),
Tokenizer = this.Options.Tokenizer,
MaxHistoryMessages = this.Options.Prompts.Options.MaxHistoryMessages,
MaxHistoryMessages = this.Prompts.Options.MaxHistoryMessages,
MaxRepairAttempts = this.Options.MaxRepairAttempts,
LogRepairs = this.Options.LogRepairs
}, this._logger);

return await client.CompletePromptAsync(context, memory, this.Options.Prompts, null, cancellationToken);
return await client.CompletePromptAsync(context, memory, this.Prompts, null, cancellationToken);
}
}
}

0 comments on commit 1863151

Please sign in to comment.