Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add unit tests for Text Search AOT enhancements #10143

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.Data;
using SemanticKernel.AotTests.Plugins;

namespace SemanticKernel.AotTests.JsonSerializerContexts;

[JsonSerializable(typeof(CustomResult))]
[JsonSerializable(typeof(int))]
[JsonSerializable(typeof(KernelSearchResults<string>))]
[JsonSerializable(typeof(KernelSearchResults<TextSearchResult>))]
[JsonSerializable(typeof(KernelSearchResults<object>))]
internal sealed partial class CustomResultJsonSerializerContext : JsonSerializerContext
{
}
12 changes: 12 additions & 0 deletions dotnet/src/SemanticKernel.AotTests/Plugins/CustomResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

namespace SemanticKernel.AotTests.Plugins;
internal sealed class CustomResult
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: new line between namespace and the class declaration.

{
public string Value { get; set; }

public CustomResult(string value)
{
this.Value = value;
}
}
5 changes: 5 additions & 0 deletions dotnet/src/SemanticKernel.AotTests/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ private static async Task<int> Main(string[] args)

// Tests for text search
VectorStoreTextSearchTests.GetTextSearchResultsAsync,
VectorStoreTextSearchTests.AddVectorStoreTextSearch,

TextSearchExtensionsTests.CreateWithSearch,
TextSearchExtensionsTests.CreateWithGetTextSearchResults,
TextSearchExtensionsTests.CreateWithGetSearchResults,
];

private static async Task<bool> RunUnitTestsAsync(IEnumerable<Func<Task>> functionsToRun)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<PackageReference Include="Microsoft.Extensions.Configuration" />
<PackageReference Include="Microsoft.Extensions.Configuration.UserSecrets" />
<PackageReference Include="MSTest.TestFramework" />
<PackageReference Include="System.Linq.Async" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel.Data;

namespace SemanticKernel.AotTests.UnitTests.Search;

internal sealed class MockTextSearch : ITextSearch
{
private readonly KernelSearchResults<object>? _objectResults;
private readonly KernelSearchResults<TextSearchResult>? _textSearchResults;
private readonly KernelSearchResults<string>? _stringResults;

public MockTextSearch(KernelSearchResults<object>? objectResults)
{
this._objectResults = objectResults;
}

public MockTextSearch(KernelSearchResults<TextSearchResult>? textSearchResults)
{
this._textSearchResults = textSearchResults;
}

public MockTextSearch(KernelSearchResults<string>? stringResults)
{
this._stringResults = stringResults;
}

public Task<KernelSearchResults<object>> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default)
{
return Task.FromResult(this._objectResults!);
}

public Task<KernelSearchResults<TextSearchResult>> GetTextSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default)
{
return Task.FromResult(this._textSearchResults!);
}

public Task<KernelSearchResults<string>> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default)
{
return Task.FromResult(this._stringResults!);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using SemanticKernel.AotTests.JsonSerializerContexts;
using SemanticKernel.AotTests.Plugins;

namespace SemanticKernel.AotTests.UnitTests.Search;

internal sealed class TextSearchExtensionsTests
{
private static readonly JsonSerializerOptions s_jsonSerializerOptions = new()
{
TypeInfoResolverChain = { CustomResultJsonSerializerContext.Default }
};

public static async Task CreateWithSearch()
{
// Arrange
var testData = new List<string> { "test-value" };
KernelSearchResults<string> results = new(testData.ToAsyncEnumerable());
ITextSearch textSearch = new MockTextSearch(results);

// Act
var plugin = textSearch.CreateWithSearch("SearchPlugin", s_jsonSerializerOptions);

// Assert
await AssertSearchFunctionSchemaAndInvocationResult<string>(plugin["Search"], testData[0]);
}

public static async Task CreateWithGetTextSearchResults()
{
// Arrange
var testData = new List<TextSearchResult> { new("test-value") };
KernelSearchResults<TextSearchResult> results = new(testData.ToAsyncEnumerable());
ITextSearch textSearch = new MockTextSearch(results);

// Act
var plugin = textSearch.CreateWithGetTextSearchResults("SearchPlugin", s_jsonSerializerOptions);

// Assert
await AssertSearchFunctionSchemaAndInvocationResult<TextSearchResult>(plugin["GetTextSearchResults"], testData[0]);
}

public static async Task CreateWithGetSearchResults()
{
// Arrange
var testData = new List<CustomResult> { new("test-value") };
KernelSearchResults<object> results = new(testData.ToAsyncEnumerable());
ITextSearch textSearch = new MockTextSearch(results);

// Act
var plugin = textSearch.CreateWithGetSearchResults("SearchPlugin", s_jsonSerializerOptions);

// Assert
await AssertSearchFunctionSchemaAndInvocationResult<object>(plugin["GetSearchResults"], testData[0]);
}

#region assert
internal static async Task AssertSearchFunctionSchemaAndInvocationResult<T>(KernelFunction function, T expectedResult)
{
// Assert input parameter schema
AssertSearchFunctionMetadata<T>(function.Metadata);

// Assert the function result
FunctionResult functionResult = await function.InvokeAsync(new(), new() { ["query"] = "Mock Query" });

var result = functionResult.GetValue<List<T>>()!;
Assert.AreEqual(1, result.Count);
Assert.AreEqual(expectedResult, result[0]);
}

internal static void AssertSearchFunctionMetadata<T>(KernelFunctionMetadata metadata)
{
// Assert input parameter schema
Assert.AreEqual(3, metadata.Parameters.Count);
Assert.AreEqual("{\"description\":\"What to search for\",\"type\":\"string\"}", metadata.Parameters[0].Schema!.ToString());
Assert.AreEqual("{\"description\":\"Number of results (default value: 2)\",\"type\":\"integer\"}", metadata.Parameters[1].Schema!.ToString());
Assert.AreEqual("{\"description\":\"Number of results to skip (default value: 0)\",\"type\":\"integer\"}", metadata.Parameters[2].Schema!.ToString());

// Assert return type schema
var type = typeof(T).Name;
var expectedSchema = type switch
{
"String" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}},\"required\":[\"Results\"]}",
"TextSearchResult" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}",
_ => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}",
};
Comment on lines +85 to +90
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var expectedSchema = type switch
{
"String" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}},\"required\":[\"Results\"]}",
"TextSearchResult" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}",
_ => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}",
};
var expectedSchema = type switch
{
"String" => """{"type":"object","properties":{"TotalCount":{"type":["integer","null"],"default":null},"Metadata":{"type":["object","null"],"default":null},"Results":{"type":"array","items":{"type":"string"}}},"required":["Results"]}""",
"TextSearchResult" => """{"type":"object","properties":{"TotalCount":{"type":["integer","null"],"default":null},"Metadata":{"type":["object","null"],"default":null},"Results":{"type":"array","items":{"type":"object","properties":{"Name":{"type":["string","null"]},"Link":{"type":["string","null"]},"Value":{"type":"string"}},"required":["Value"]}}},"required":["Results"]}""",
_ => """{"type":"object","properties":{"TotalCount":{"type":["integer","null"],"default":null},"Metadata":{"type":["object","null"],"default":null},"Results":{"type":"array","items":{"type":"object","properties":{"Name":{"type":["string","null"]},"Link":{"type":["string","null"]},"Value":{"type":"string"}},"required":["Value"]}}},"required":["Results"]}"""
};

Assert.AreEqual(expectedSchema, metadata.ReturnParameter.Schema!.ToString());
}
#endregion
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;

Expand Down Expand Up @@ -35,6 +37,39 @@ public static async Task GetTextSearchResultsAsync()
Assert.AreEqual("test-link", results[0].Link);
}

public static async Task AddVectorStoreTextSearch()
{
// Arrange
var testData = new List<VectorSearchResult<DataModel>>
{
new(new DataModel { Key = "test-name", Text = "test-result", Link = "test-link" }, 0.5)
};
var vectorizableTextSearch = new MockVectorizableTextSearch<DataModel>(testData);
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IVectorizableTextSearch<DataModel>>(vectorizableTextSearch);

// Act
serviceCollection.AddVectorStoreTextSearch<DataModel>();
var textSearch = serviceCollection.BuildServiceProvider().GetService<VectorStoreTextSearch<DataModel>>();
Assert.IsNotNull(textSearch);

// Assert
KernelSearchResults<TextSearchResult> searchResults = await textSearch.GetTextSearchResultsAsync("query");

List<TextSearchResult> results = [];

await foreach (TextSearchResult result in searchResults.Results)
{
results.Add(result);
}

// Assert
Assert.AreEqual(1, results.Count);
Assert.AreEqual("test-name", results[0].Name);
Assert.AreEqual("test-result", results[0].Value);
Assert.AreEqual("test-link", results[0].Link);
}

private sealed class DataModel
{
[TextSearchResultName]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,9 @@ private static KernelFunctionFromMethodOptions DefaultGetSearchResultsMethodOpti
private static IEnumerable<KernelParameterMetadata> CreateDefaultKernelParameterMetadata(JsonSerializerOptions jsonSerializerOptions)
{
return [
new KernelParameterMetadata("query", jsonSerializerOptions) { Description = "What to search for", IsRequired = true },
new KernelParameterMetadata("count", jsonSerializerOptions) { Description = "Number of results", IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip", jsonSerializerOptions) { Description = "Number of results to skip", IsRequired = false, DefaultValue = 0 },
new KernelParameterMetadata("query", jsonSerializerOptions) { Description = "What to search for", ParameterType = typeof(string), IsRequired = true },
new KernelParameterMetadata("count", jsonSerializerOptions) { Description = "Number of results", ParameterType = typeof(int), IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip", jsonSerializerOptions) { Description = "Number of results to skip", ParameterType = typeof(int), IsRequired = false, DefaultValue = 0 },
];
}

Expand All @@ -530,9 +530,9 @@ private static IEnumerable<KernelParameterMetadata> CreateDefaultKernelParameter
private static IEnumerable<KernelParameterMetadata> GetDefaultKernelParameterMetadata()
{
return s_kernelParameterMetadata ??= [
new KernelParameterMetadata("query") { Description = "What to search for", IsRequired = true },
new KernelParameterMetadata("count") { Description = "Number of results", IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip") { Description = "Number of results to skip", IsRequired = false, DefaultValue = 0 },
new KernelParameterMetadata("query") { Description = "What to search for", ParameterType = typeof(string), IsRequired = true },
new KernelParameterMetadata("count") { Description = "Number of results", ParameterType = typeof(int), IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip") { Description = "Number of results to skip", ParameterType = typeof(int), IsRequired = false, DefaultValue = 0 },
];
}

Expand Down
Loading