Refactors using language models to support OpenAI-compatible clients #1139

Open · wants to merge 1 commit into base: main
@@ -6,7 +6,7 @@ namespace DevProxy.Abstractions.LanguageModel;

 public interface ILanguageModelClient
 {
-    Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages);
+    Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages, CompletionOptions? options = null);
     Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options = null);
     Task<bool> IsEnabledAsync();
 }
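
Because the new options parameter is optional, existing call sites that pass only a messages array keep compiling. A hypothetical caller that tunes sampling might look like the sketch below; CompletionOptions is assumed to expose a Temperature property (it is read as options?.Temperature later in this diff), and languageModelClient stands for any resolved ILanguageModelClient:

    // Hypothetical caller sketch; names not shown in this diff are assumptions.
    var messages = new ILanguageModelChatCompletionMessage[]
    {
        new OpenAIChatCompletionMessage { Role = "user", Content = "Are you there? Reply with a yes or no." }
    };

    // Old shape still compiles:
    var plain = await languageModelClient.GenerateChatCompletionAsync(messages);

    // New shape threads per-request options through:
    var tuned = await languageModelClient.GenerateChatCompletionAsync(
        messages,
        new CompletionOptions { Temperature = 0.2 });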
@@ -12,9 +12,9 @@ public static ILanguageModelClient Create(LanguageModelConfiguration? config, IL
     {
         return config?.Client switch
         {
-            LanguageModelClient.LMStudio => new LMStudioLanguageModelClient(config, logger),
             LanguageModelClient.Ollama => new OllamaLanguageModelClient(config, logger),
-            _ => new OllamaLanguageModelClient(config, logger)
+            LanguageModelClient.OpenAI => new OpenAILanguageModelClient(config, logger),
+            _ => new OpenAILanguageModelClient(config, logger)
         };
     }
 }
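
With the new default arm, anything that is not explicitly configured as Ollama now resolves to the OpenAI-compatible client. A hypothetical call site follows; the factory's class name is truncated in the hunk header, so LanguageModelClientFactory is an assumption:

    // config.Client == LanguageModelClient.Ollama              -> OllamaLanguageModelClient (native Ollama API)
    // config.Client == LanguageModelClient.OpenAI, any other
    // value, or a null config                                  -> OpenAILanguageModelClient
    ILanguageModelClient client = LanguageModelClientFactory.Create(config, logger);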
@@ -6,16 +6,16 @@ namespace DevProxy.Abstractions.LanguageModel;

 public enum LanguageModelClient
 {
-    LMStudio,
-    Ollama
+    Ollama,
+    OpenAI
 }

 public class LanguageModelConfiguration
 {
     public bool CacheResponses { get; set; } = true;
+    public LanguageModelClient Client { get; set; } = LanguageModelClient.OpenAI;
     public bool Enabled { get; set; } = false;
-    public LanguageModelClient Client { get; set; } = LanguageModelClient.Ollama;
-    public string Model { get; set; } = "llama3.2";
+    public string? Model { get; set; } = "llama3.2";
     // default Ollama URL
-    public string? Url { get; set; } = "http://localhost:11434";
+    public string? Url { get; set; } = "http://localhost:11434/v1/";
 }

Comment from the author (Collaborator): Still uses Ollama by default, but uses its OpenAI-compatible APIs rather than the Ollama-specific ones.
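
In practice the new defaults keep Dev Proxy talking to a local Ollama instance, just through its OpenAI-compatible /v1/ surface, and the same shape can point at any other OpenAI-compatible endpoint. A configuration sketch that restates only values visible in this diff (the commented alternative is illustrative):

    var config = new LanguageModelConfiguration
    {
        Enabled = true,                       // still opt-in; the default is false
        Client = LanguageModelClient.OpenAI,  // default client implementation
        Model = "llama3.2",
        Url = "http://localhost:11434/v1/"    // Ollama's OpenAI-compatible endpoint
    };

    // Any other OpenAI-compatible endpoint could be substituted, e.g. (illustrative values):
    // config.Url = "http://localhost:1234/v1/"; config.Model = "my-local-model";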
@@ -50,16 +50,6 @@ private async Task<bool> IsEnabledInternalAsync()

         try
         {
-            // check if lm is on
-            using var client = new HttpClient();
-            var response = await client.GetAsync(_configuration.Url);
-            _logger.LogDebug("Response: {response}", response.StatusCode);
-
-            if (!response.IsSuccessStatusCode)
-            {
-                return false;
-            }
-
             var testCompletion = await GenerateCompletionInternalAsync("Are you there? Reply with a yes or no.");
             if (testCompletion?.Error is not null)
             {
@@ -160,7 +150,7 @@ private async Task<bool> IsEnabledInternalAsync()
         }
     }

-    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages)
+    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages, CompletionOptions? options = null)
     {
         using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

@@ -186,7 +176,7 @@ private async Task<bool> IsEnabledInternalAsync()
             return cachedResponse;
         }

-        var response = await GenerateChatCompletionInternalAsync(messages);
+        var response = await GenerateChatCompletionInternalAsync(messages, options);
         if (response == null)
         {
             return null;
@@ -207,7 +197,7 @@ private async Task<bool> IsEnabledInternalAsync()
         }
     }

-    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternalAsync(ILanguageModelChatCompletionMessage[] messages)
+    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternalAsync(ILanguageModelChatCompletionMessage[] messages, CompletionOptions? options = null)
     {
         Debug.Assert(_configuration != null, "Configuration is null");

@@ -222,7 +212,8 @@ private async Task<bool> IsEnabledInternalAsync()
                 {
                     messages,
                     model = _configuration.Model,
-                    stream = false
+                    stream = false,
+                    options
                 }
             );
             _logger.LogDebug("Response: {response}", response.StatusCode);
@@ -8,12 +8,11 @@

 namespace DevProxy.Abstractions.LanguageModel;

-public class LMStudioLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
+public class OpenAILanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
 {
     private readonly LanguageModelConfiguration? _configuration = configuration;
     private readonly ILogger _logger = logger;
     private bool? _lmAvailable;
-    private readonly Dictionary<string, OpenAICompletionResponse> _cacheCompletion = [];
     private readonly Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> _cacheChatCompletion = [];

     public async Task<bool> IsEnabledAsync()
@@ -29,6 +28,8 @@ public async Task<bool> IsEnabledAsync()

     private async Task<bool> IsEnabledInternalAsync()
     {
+        using var scope = _logger.BeginScope(nameof(OpenAILanguageModelClient));
+
         if (_configuration is null || !_configuration.Enabled)
         {
             return false;
@@ -50,20 +51,14 @@ private async Task<bool> IsEnabledInternalAsync()

         try
         {
-            // check if lm is on
-            using var client = new HttpClient();
-            var response = await client.GetAsync($"{_configuration.Url}/v1/models");
-            _logger.LogDebug("Response: {response}", response.StatusCode);
-
-            if (!response.IsSuccessStatusCode)
-            {
-                return false;
-            }
-
-            var testCompletion = await GenerateCompletionInternalAsync("Are you there? Reply with a yes or no.");
-            if (testCompletion?.Error is not null)
+            var testCompletion = await GenerateChatCompletionInternalAsync([new()
+            {
+                Content = "Are you there? Reply with a yes or no.",
+                Role = "user"
+            }]);
+            if (testCompletion?.ErrorMessage is not null)
             {
-                _logger.LogError("Error: {error}. Param: {param}", testCompletion.Error.Message, testCompletion.Error.Param);
+                _logger.LogError("Error: {error}", testCompletion.ErrorMessage);
                 return false;
             }

@@ -78,90 +73,41 @@ private async Task<bool> IsEnabledInternalAsync()

     public async Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options = null)
     {
-        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));
-
-        if (_configuration is null)
-        {
-            return null;
-        }
-
-        if (!_lmAvailable.HasValue)
-        {
-            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabledAsync));
-            return null;
-        }
-
-        if (!_lmAvailable.Value)
-        {
-            return null;
-        }
-
-        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
-        {
-            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
-            return cachedResponse;
-        }
-
-        var response = await GenerateCompletionInternalAsync(prompt, options);
+        var response = await GenerateChatCompletionAsync([new OpenAIChatCompletionMessage() { Content = prompt, Role = "user" }], options);
         if (response == null)
         {
            return null;
         }
-        if (response.Error is not null)
+        if (response.ErrorMessage is not null)
         {
-            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
+            _logger.LogError("Error: {error}", response.ErrorMessage);
            return null;
         }
         else
         {
-            if (_configuration.CacheResponses && response.Response is not null)
-            {
-                _cacheCompletion[prompt] = response;
-            }
+            var openAIResponse = (OpenAIChatCompletionResponse)response;

-            return response;
-        }
-    }
-
-    private async Task<OpenAICompletionResponse?> GenerateCompletionInternalAsync(string prompt, CompletionOptions? options = null)
-    {
-        Debug.Assert(_configuration != null, "Configuration is null");
-
-        try
-        {
-            using var client = new HttpClient();
-            var url = $"{_configuration.Url}/v1/completions";
-            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);
-
-            var response = await client.PostAsJsonAsync(url,
-                new
-                {
-                    prompt,
-                    model = _configuration.Model,
-                    stream = false,
-                    temperature = options?.Temperature ?? 0.8,
-                }
-            );
-            _logger.LogDebug("Response: {response}", response.StatusCode);
-
-            var res = await response.Content.ReadFromJsonAsync<OpenAICompletionResponse>();
-            if (res is null)
-            {
-                return res;
-            }
-            res.RequestUrl = url;
-            return res;
-        }
-        catch (Exception ex)
-        {
-            _logger.LogError(ex, "Failed to generate completion");
-            return null;
-        }
-    }
+            return new OpenAICompletionResponse
+            {
+                Choices = openAIResponse.Choices?.Select(c => new OpenAICompletionResponseChoice
+                {
+                    ContentFilterResults = c.ContentFilterResults,
+                    FinishReason = c.FinishReason,
+                    Index = c.Index,
+                    LogProbabilities = c.LogProbabilities,
+                    Text = c.Message.Content
+                }).ToArray(),
+                Created = openAIResponse.Created,
+                Error = openAIResponse.Error,
+                Id = openAIResponse.Id,
+                Model = openAIResponse.Model,
+                Object = openAIResponse.Object,
+                PromptFilterResults = openAIResponse.PromptFilterResults,
+                Usage = openAIResponse.Usage,
+            };
+        }
+    }

-    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages)
+    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages, CompletionOptions? options = null)
     {
-        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));
+        using var scope = _logger.BeginScope(nameof(OpenAILanguageModelClient));

         if (_configuration is null)
         {
@@ -185,14 +131,14 @@ private async Task<bool> IsEnabledInternalAsync()
             return cachedResponse;
         }

-        var response = await GenerateChatCompletionInternalAsync(messages);
+        var response = await GenerateChatCompletionInternalAsync([.. messages.Select(m => (OpenAIChatCompletionMessage)m)], options);
         if (response == null)
         {
             return null;
         }
         if (response.Error is not null)
         {
-            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
+            _logger.LogError("Error: {error}. Code: {code}", response.Error.Message, response.Error.Code);
             return null;
         }
         else
@@ -206,24 +152,25 @@ private async Task<bool> IsEnabledInternalAsync()
         }
     }

-    private async Task<OpenAIChatCompletionResponse?> GenerateChatCompletionInternalAsync(ILanguageModelChatCompletionMessage[] messages)
+    private async Task<OpenAIChatCompletionResponse?> GenerateChatCompletionInternalAsync(OpenAIChatCompletionMessage[] messages, CompletionOptions? options = null)
     {
         Debug.Assert(_configuration != null, "Configuration is null");

         try
         {
             using var client = new HttpClient();
-            var url = $"{_configuration.Url}/v1/chat/completions";
+            var url = $"{_configuration.Url}/chat/completions";
             _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

-            var response = await client.PostAsJsonAsync(url,
-                new
-                {
-                    messages,
-                    model = _configuration.Model,
-                    stream = false
-                }
-            );
+            var payload = new OpenAIChatCompletionRequest
+            {
+                Messages = messages,
+                Model = _configuration.Model,
+                Stream = false,
+                Temperature = options?.Temperature
+            };
+
+            var response = await client.PostAsJsonAsync(url, payload);
             _logger.LogDebug("Response: {response}", response.StatusCode);

             var res = await response.Content.ReadFromJsonAsync<OpenAIChatCompletionResponse>();
@@ -243,7 +190,7 @@ private async Task<bool> IsEnabledInternalAsync()
     }
 }

-internal static class CacheChatCompletionExtensions
+internal static class OpenAICacheChatCompletionExtensions
 {
     public static OpenAIChatCompletionMessage[]? GetKey(
         this Dictionary<OpenAIChatCompletionMessage[], OpenAIChatCompletionResponse> cache,
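
A minimal caller sketch for the refactored client, using only members that appear in this diff (the surrounding plugin wiring is assumed). GenerateCompletionAsync now wraps the prompt in a single user chat message and maps the chat choices back onto OpenAICompletionResponse, so prompt-based callers still receive Text choices:

    var client = new OpenAILanguageModelClient(configuration, logger);
    if (await client.IsEnabledAsync())
    {
        var completion = await client.GenerateCompletionAsync(
            "Are you there? Reply with a yes or no.",
            new CompletionOptions { Temperature = 0.2 });
        // completion is an OpenAICompletionResponse; each choice's Text is copied from the
        // underlying chat completion's message content.
    }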
23 changes: 11 additions & 12 deletions dev-proxy-abstractions/LanguageModel/OpenAIModels.cs
@@ -9,16 +9,17 @@ namespace DevProxy.Abstractions.LanguageModel;

 public abstract class OpenAIRequest
 {
     [JsonPropertyName("frequency_penalty")]
-    public long FrequencyPenalty { get; set; }
+    public long? FrequencyPenalty { get; set; }
     [JsonPropertyName("max_tokens")]
-    public long MaxTokens { get; set; }
+    public long? MaxTokens { get; set; }
     public string Model { get; set; } = string.Empty;
     [JsonPropertyName("presence_penalty")]
-    public long PresencePenalty { get; set; }
+    public long? PresencePenalty { get; set; }
     public object? Stop { get; set; }
-    public bool Stream { get; set; }
-    public long Temperature { get; set; }
+    public bool? Stream { get; set; }
+    public double? Temperature { get; set; }
     [JsonPropertyName("top_p")]
-    public double TopP { get; set; }
+    public double? TopP { get; set; }
 }

 public class OpenAICompletionRequest : OpenAIRequest
@@ -33,10 +34,8 @@ public class OpenAIChatCompletionRequest : OpenAIRequest

 public class OpenAIError
 {
-    public string? Message { get; set; }
-    public string? Type { get; set; }
     public string? Code { get; set; }
-    public string? Param { get; set; }
+    public string? Message { get; set; }
 }

 public abstract class OpenAIResponse: ILanguageModelCompletionResponse
@@ -77,10 +76,10 @@ public abstract class OpenAIResponseChoice

     [JsonPropertyName("content_filter_results")]
     public Dictionary<string, OpenAIResponseContentFilterResult> ContentFilterResults { get; set; } = new();
     [JsonPropertyName("finish_reason")]
-    public string FinishReason { get; set; } = "length";
+    public string FinishReason { get; set; } = "stop";
     public long Index { get; set; }
-    [JsonIgnore(Condition = JsonIgnoreCondition.Never)]
-    public object? Logprobs { get; set; }
+    [JsonPropertyName("logprobs")]
+    public int? LogProbabilities { get; set; }
 }

 public class OpenAIResponsePromptFilterResult
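
With the request properties now nullable, a payload only sets what it needs, and Temperature can carry fractional values (it was previously a long). Whether unset members are omitted from the serialized JSON depends on serializer options this diff does not show:

    var payload = new OpenAIChatCompletionRequest
    {
        Messages = messages,          // OpenAIChatCompletionMessage[], as in the client above
        Model = "llama3.2",
        Stream = false,
        Temperature = 0.2             // double? now, so values such as 0.2 are representable
    };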
2 changes: 2 additions & 0 deletions dev-proxy/CommandHandlers/ProxyCommandHandler.cs
@@ -213,6 +213,8 @@ private async Task CheckForNewVersionAsync()
                 configObject.UrlsToWatch = urlsSection.Get<List<string>>() ?? [];
             }

+            configObject.LanguageModel?.Url?.TrimEnd('/');
+
             return configObject;
         });
     }
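
Note that string.TrimEnd returns a new string rather than mutating the instance, so a variant that persists the normalized URL would assign the result back. A sketch under that assumption (illustrative, not taken from this change):

    if (configObject.LanguageModel?.Url is not null)
    {
        // Normalize the configured base URL so segments like "/chat/completions" append cleanly.
        configObject.LanguageModel.Url = configObject.LanguageModel.Url.TrimEnd('/');
    }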