8 changes: 2 additions & 6 deletions conversation/anthropic/anthropic.go
@@ -41,19 +41,15 @@ func NewAnthropic(logger logger.Logger) conversation.Conversation {
 	return a
 }
 
-const defaultModel = "claude-3-5-sonnet-20240620"
-
 func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error {
 	m := conversation.LangchainMetadata{}
 	err := kmeta.DecodeMetadata(meta.Properties, &m)
 	if err != nil {
 		return err
 	}
 
-	model := defaultModel
-	if m.Model != "" {
-		model = m.Model
-	}
+	// Resolve model via central helper (uses env var, then metadata, then default)
+	model := conversation.GetAnthropicModel(m.Model)
 
 	llm, err := anthropic.New(
 		anthropic.WithModel(model),
6 changes: 3 additions & 3 deletions conversation/anthropic/metadata.yaml
@@ -24,10 +24,10 @@ metadata:
   - name: model
     required: false
     description: |
-      The Anthropic LLM to use.
+      The Anthropic LLM to use. Configurable via the ANTHROPIC_MODEL environment variable.
     type: string
-    example: 'claude-3-5-sonnet-20240620'
-    default: 'claude-3-5-sonnet-20240620'
+    example: 'claude-sonnet-4-20250514'
+    default: 'claude-sonnet-4-20250514'
   - name: cacheTTL
     required: false
     description: |
8 changes: 2 additions & 6 deletions conversation/googleai/googleai.go
@@ -41,19 +41,15 @@ func NewGoogleAI(logger logger.Logger) conversation.Conversation {
 	return g
 }
 
-const defaultModel = "gemini-2.5-flash"
-
 func (g *GoogleAI) Init(ctx context.Context, meta conversation.Metadata) error {
 	md := conversation.LangchainMetadata{}
 	err := kmeta.DecodeMetadata(meta.Properties, &md)
 	if err != nil {
 		return err
 	}
 
-	model := defaultModel
-	if md.Model != "" {
-		model = md.Model
-	}
+	// Resolve model via central helper (uses env var, then metadata, then default)
+	model := conversation.GetGoogleAIModel(md.Model)
 
 	opts := []openai.Option{
 		openai.WithModel(model),
6 changes: 3 additions & 3 deletions conversation/googleai/metadata.yaml
@@ -24,10 +24,10 @@ metadata:
   - name: model
     required: false
     description: |
-      The GoogleAI LLM to use.
+      The GoogleAI LLM to use. Configurable via the GOOGLEAI_MODEL environment variable.
    type: string
-    example: 'gemini-2.5-flash'
-    default: 'gemini-2.5-flash'
+    example: 'gemini-2.5-flash-lite'
+    default: 'gemini-2.5-flash-lite'
   - name: cacheTTL
     required: false
     description: |
9 changes: 2 additions & 7 deletions conversation/huggingface/huggingface.go
@@ -42,9 +42,6 @@ func NewHuggingface(logger logger.Logger) conversation.Conversation {
 	return h
 }
 
-// Default model - using a popular and reliable model
-const defaultModel = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
-
 // Default HuggingFace OpenAI-compatible endpoint
 const defaultEndpoint = "https://router.huggingface.co/hf-inference/models/{{model}}/v1"
 
@@ -55,10 +52,8 @@ func (h *Huggingface) Init(ctx context.Context, meta conversation.Metadata) error {
 		return err
 	}
 
-	model := defaultModel
-	if m.Model != "" {
-		model = m.Model
-	}
+	// Resolve model via central helper (uses env var, then metadata, then default)
+	model := conversation.GetHuggingFaceModel(m.Model)
 
 	endpoint := strings.Replace(defaultEndpoint, "{{model}}", model, 1)
 	if m.Endpoint != "" {
2 changes: 1 addition & 1 deletion conversation/huggingface/metadata.yaml
@@ -24,7 +24,7 @@ metadata:
   - name: model
     required: false
     description: |
-      The Huggingface model to use. Uses OpenAI-compatible API.
+      The Huggingface model to use. Uses OpenAI-compatible API. Configurable via the HUGGINGFACE_MODEL environment variable.
     type: string
     example: 'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B'
     default: 'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B'
2 changes: 1 addition & 1 deletion conversation/metadata_test.go
@@ -25,7 +25,7 @@ func TestLangchainMetadata(t *testing.T) {
 	t.Run("json marshaling with endpoint", func(t *testing.T) {
 		metadata := LangchainMetadata{
 			Key:      "test-key",
-			Model:    "gpt-4",
+			Model:    DefaultOpenAIModel,
 			CacheTTL: "10m",
 			Endpoint: "https://custom-endpoint.example.com",
 		}
2 changes: 1 addition & 1 deletion conversation/mistral/metadata.yaml
@@ -24,7 +24,7 @@ metadata:
   - name: model
     required: false
     description: |
-      The Mistral LLM to use.
+      The Mistral LLM to use. Configurable via the MISTRAL_MODEL environment variable.
     type: string
     example: 'open-mistral-7b'
     default: 'open-mistral-7b'
8 changes: 2 additions & 6 deletions conversation/mistral/mistral.go
@@ -43,19 +43,15 @@ func NewMistral(logger logger.Logger) conversation.Conversation {
 	return m
 }
 
-const defaultModel = "open-mistral-7b"
-
 func (m *Mistral) Init(ctx context.Context, meta conversation.Metadata) error {
 	md := conversation.LangchainMetadata{}
 	err := kmeta.DecodeMetadata(meta.Properties, &md)
 	if err != nil {
 		return err
 	}
 
-	model := defaultModel
-	if md.Model != "" {
-		model = md.Model
-	}
+	// Resolve model via central helper (uses env var, then metadata, then default)
+	model := conversation.GetMistralModel(md.Model)
 
 	llm, err := mistral.New(
 		mistral.WithModel(model),
85 changes: 85 additions & 0 deletions conversation/models.go
@@ -0,0 +1,85 @@
+/*
+Copyright 2025 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package conversation
+
+import (
+	"os"
+)
+
+// Default models for conversation components.
+// These can be overridden via environment variables for runtime configuration.
+const (
+	// Environment variable names
+	envOpenAIModel      = "OPENAI_MODEL"
+	envAzureOpenAIModel = "AZURE_OPENAI_MODEL"
+	envAnthropicModel   = "ANTHROPIC_MODEL"
+	envGoogleAIModel    = "GOOGLEAI_MODEL"
+	envMistralModel     = "MISTRAL_MODEL"
+	envHuggingFaceModel = "HUGGINGFACE_MODEL"
+	envOllamaModel      = "OLLAMA_MODEL"
+)
+
+// Exported default model constants for consumers of the conversation package.
+// These are used as fallbacks when env vars and metadata are not set.
+const (
+	DefaultOpenAIModel      = "gpt-5-nano"   // Default OpenAI model
+	DefaultAzureOpenAIModel = "gpt-4.1-nano" // Default Azure OpenAI model
+	DefaultAnthropicModel   = "claude-sonnet-4-20250514"
+	DefaultGoogleAIModel    = "gemini-2.5-flash-lite"
+	DefaultMistralModel     = "open-mistral-7b"
+	DefaultHuggingFaceModel = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+	DefaultOllamaModel      = "llama3.2:latest"
+)
+
+// getModel resolves a model name: the environment variable takes precedence, then the metadata value, then the default.
+func getModel(envVar, defaultValue, metadataValue string) string {
+	if value := os.Getenv(envVar); value != "" {
+		return value
+	}
+	if metadataValue != "" {
+		return metadataValue
+	}
+	return defaultValue
+}
+
+// Model getters with metadata support: pass the model value from your
+// metadata file/struct, or "" if not set.
+func GetOpenAIModel(metadataValue string) string {
+	return getModel(envOpenAIModel, DefaultOpenAIModel, metadataValue)
+}
+
+func GetAzureOpenAIModel(metadataValue string) string {
+	return getModel(envAzureOpenAIModel, DefaultAzureOpenAIModel, metadataValue)
+}
+
+func GetAnthropicModel(metadataValue string) string {
+	return getModel(envAnthropicModel, DefaultAnthropicModel, metadataValue)
+}
+
+func GetGoogleAIModel(metadataValue string) string {
+	return getModel(envGoogleAIModel, DefaultGoogleAIModel, metadataValue)
+}
+
+func GetMistralModel(metadataValue string) string {
+	return getModel(envMistralModel, DefaultMistralModel, metadataValue)
+}
+
+func GetHuggingFaceModel(metadataValue string) string {
+	return getModel(envHuggingFaceModel, DefaultHuggingFaceModel, metadataValue)
+}
+
+func GetOllamaModel(metadataValue string) string {
+	return getModel(envOllamaModel, DefaultOllamaModel, metadataValue)
+}
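
For review context, here is a minimal sketch (not part of this PR) of the resolution order getModel implements: environment variable first, then metadata, then the package default. It assumes the import path github.com/dapr/components-contrib/conversation, Go 1.17+ for t.Setenv, and that OPENAI_MODEL is not already set in the test environment:

package conversation_test

import (
	"testing"

	"github.com/dapr/components-contrib/conversation"
)

func TestModelResolutionOrder(t *testing.T) {
	// No env var and no metadata: the package default applies
	// (assumes OPENAI_MODEL is unset when the test starts).
	if got := conversation.GetOpenAIModel(""); got != conversation.DefaultOpenAIModel {
		t.Fatalf("want %q, got %q", conversation.DefaultOpenAIModel, got)
	}

	// Metadata set, no env var: metadata wins over the default.
	if got := conversation.GetOpenAIModel("gpt-4o"); got != "gpt-4o" {
		t.Fatalf("want gpt-4o, got %q", got)
	}

	// Env var set: it wins over both metadata and the default.
	t.Setenv("OPENAI_MODEL", "my-env-model")
	if got := conversation.GetOpenAIModel("gpt-4o"); got != "my-env-model" {
		t.Fatalf("want my-env-model, got %q", got)
	}
}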
2 changes: 1 addition & 1 deletion conversation/ollama/metadata.yaml
@@ -12,7 +12,7 @@ metadata:
   - name: model
     required: false
     description: |
-      The Ollama LLM to use.
+      The Ollama LLM to use. Configurable via the OLLAMA_MODEL environment variable.
     type: string
     example: 'llama3.2:latest'
     default: 'llama3.2:latest'
8 changes: 2 additions & 6 deletions conversation/ollama/ollama.go
@@ -41,19 +41,15 @@ func NewOllama(logger logger.Logger) conversation.Conversation {
 	return o
 }
 
-const defaultModel = "llama3.2:latest"
-
 func (o *Ollama) Init(ctx context.Context, meta conversation.Metadata) error {
 	md := conversation.LangchainMetadata{}
 	err := kmeta.DecodeMetadata(meta.Properties, &md)
 	if err != nil {
 		return err
 	}
 
-	model := defaultModel
-	if md.Model != "" {
-		model = md.Model
-	}
+	// Resolve model via central helper (uses env var, then metadata, then default)
+	model := conversation.GetOllamaModel(md.Model)
 
 	llm, err := ollama.New(
 		ollama.WithModel(model),
8 changes: 4 additions & 4 deletions conversation/openai/metadata.yaml
@@ -24,10 +24,10 @@ metadata:
   - name: model
     required: false
    description: |
-      The OpenAI LLM to use.
+      The OpenAI LLM to use. Configurable via the OPENAI_MODEL environment variable.
     type: string
-    example: 'gpt-4-turbo'
-    default: 'gpt-4o'
+    default: 'gpt-5-nano'
+    example: 'gpt-5-nano'
   - name: endpoint
     required: false
     description: |
@@ -53,4 +53,4 @@ metadata:
       The type of API to use for the OpenAI service. This is required when using Azure OpenAI.
     type: string
     example: 'azure'
-    default: ''
+    default: 'gpt-4.1-nano'
11 changes: 6 additions & 5 deletions conversation/openai/openai.go
@@ -42,18 +42,19 @@ func NewOpenAI(logger logger.Logger) conversation.Conversation {
 	return o
 }
 
-const defaultModel = "gpt-4o"
-
 func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error {
 	md := OpenAILangchainMetadata{}
 	err := kmeta.DecodeMetadata(meta.Properties, &md)
 	if err != nil {
 		return err
 	}
 
-	model := defaultModel
-	if md.Model != "" {
-		model = md.Model
+	// Resolve model via central helper (uses env var, then metadata, then default)
+	var model string
+	if md.APIType == "azure" {
+		model = conversation.GetAzureOpenAIModel(md.Model)
+	} else {
+		model = conversation.GetOpenAIModel(md.Model)
 	}
 	// Create options for OpenAI client
 	options := []openai.Option{
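
Worth noting on the branch above: with apiType set to "azure", resolution goes through GetAzureOpenAIModel, so AZURE_OPENAI_MODEL (not OPENAI_MODEL) is the overriding environment variable and DefaultAzureOpenAIModel is the fallback. A small illustrative sketch (again not part of the PR, same assumed import path):

package main

import (
	"fmt"
	"os"

	"github.com/dapr/components-contrib/conversation"
)

func main() {
	// No env var and no model in metadata: the Azure-specific default applies.
	fmt.Println(conversation.GetAzureOpenAIModel("")) // gpt-4.1-nano

	// AZURE_OPENAI_MODEL overrides both metadata and the default.
	os.Setenv("AZURE_OPENAI_MODEL", "my-azure-deployment")
	fmt.Println(conversation.GetAzureOpenAIModel("gpt-4o")) // my-azure-deployment
}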
10 changes: 5 additions & 5 deletions conversation/openai/openai_test.go
@@ -34,7 +34,7 @@ func TestInit(t *testing.T) {
 			name: "with default endpoint",
 			metadata: map[string]string{
 				"key":   "test-key",
-				"model": "gpt-4",
+				"model": conversation.DefaultOpenAIModel,
 			},
 			testFn: func(t *testing.T, o *OpenAI, err error) {
 				require.NoError(t, err)
@@ -45,7 +45,7 @@
 			name: "with custom endpoint",
 			metadata: map[string]string{
 				"key":      "test-key",
-				"model":    "gpt-4",
+				"model":    conversation.DefaultOpenAIModel,
 				"endpoint": "https://api.openai.com/v1",
 			},
 			testFn: func(t *testing.T, o *OpenAI, err error) {
@@ -59,7 +59,7 @@
 			name: "with apiType azure and missing apiVersion",
 			metadata: map[string]string{
 				"key":      "test-key",
-				"model":    "gpt-4",
+				"model":    conversation.DefaultOpenAIModel,
 				"apiType":  "azure",
 				"endpoint": "https://custom-endpoint.openai.azure.com/",
 			},
@@ -72,7 +72,7 @@
 			name: "with apiType azure and custom apiVersion",
 			metadata: map[string]string{
 				"key":        "test-key",
-				"model":      "gpt-4",
+				"model":      conversation.DefaultOpenAIModel,
 				"apiType":    "azure",
 				"endpoint":   "https://custom-endpoint.openai.azure.com/",
 				"apiVersion": "2025-01-01-preview",
@@ -86,7 +86,7 @@
 			name: "with apiType azure but missing endpoint",
 			metadata: map[string]string{
 				"key":        "test-key",
-				"model":      "gpt-4",
+				"model":      conversation.DefaultOpenAIModel,
 				"apiType":    "azure",
 				"apiVersion": "2025-01-01-preview",
 			},