Skip to content

Commit c8eae2a

Browse files
authored
enhance: add support for azure deployment name mapping (#78)
1 parent 1db7708 commit c8eae2a

File tree

5 files changed

+33
-13
lines changed

5 files changed

+33
-13
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
**/node_modules/
55
**/package-lock.json
66
**/__pycache__
7+
/docs/yarn.lock

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ Download and install the archive for your platform and architecture from the [re
7070
export OPENAI_API_KEY="your-api-key"
7171
```
7272

73+
Alternatively Azure OpenAI can be utilized
74+
75+
```shell
76+
export OPENAI_API_KEY="your-api-key"
77+
export OPENAI_BASE_URL="your-endpiont"
78+
export OPENAI_API_TYPE="AZURE"
79+
export OPENAI_AZURE_DEPLOYMENT="your-deployment-name"
80+
```
81+
7382
#### Windows
7483

7584
```powershell

go.mod

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ module github.com/gptscript-ai/gptscript
22

33
go 1.22.0
44

5-
replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185
6-
75
require (
86
github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69
97
github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d
@@ -14,7 +12,7 @@ require (
1412
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
1513
github.com/olahol/melody v1.1.4
1614
github.com/rs/cors v1.10.1
17-
github.com/sashabaranov/go-openai v1.18.3
15+
github.com/sashabaranov/go-openai v1.20.1
1816
github.com/sirupsen/logrus v1.9.3
1917
github.com/spf13/cobra v1.8.0
2018
github.com/stretchr/testify v1.8.4

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
4040
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
4141
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
4242
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
43-
github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185 h1:+TfC9DYtWuexdL7x1lIdD1HP61IStb3ZTj/byBdiWs0=
44-
github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
4543
github.com/hexops/autogold v0.8.1/go.mod h1:97HLDXyG23akzAoRYJh/2OBs3kd80eHyKPvZw0S5ZBY=
4644
github.com/hexops/autogold v1.3.1 h1:YgxF9OHWbEIUjhDbpnLhgVsjUDsiHDTyDfy2lrfdlzo=
4745
github.com/hexops/autogold v1.3.1/go.mod h1:sQO+mQUCVfxOKPht+ipDSkJ2SCJ7BNJVHZexsXqWMx4=
@@ -97,6 +95,8 @@ github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
9795
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
9896
github.com/samber/slog-logrus v1.0.0 h1:SsrN0p9akjCEaYd42Q5GtisMdHm0q11UD4fp4XCZi04=
9997
github.com/samber/slog-logrus v1.0.0/go.mod h1:ZTdPCmVWljwlfjz6XflKNvW4TcmYlexz4HMUOO/42bI=
98+
github.com/sashabaranov/go-openai v1.20.1 h1:cFnTixAtc0I0cCBFr8gkvEbGCm6Rjf2JyoVWCjXwy9g=
99+
github.com/sashabaranov/go-openai v1.20.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
100100
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
101101
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
102102
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=

pkg/openai/client.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const (
2727
var (
2828
key = os.Getenv("OPENAI_API_KEY")
2929
url = os.Getenv("OPENAI_URL")
30+
azureModel = os.Getenv("OPENAI_AZURE_DEPLOYMENT")
3031
completionID int64
3132
)
3233

@@ -80,6 +81,15 @@ func complete(opts ...Options) (result Options, err error) {
8081
return result, err
8182
}
8283

84+
func AzureMapperFunction(model string) string {
85+
if azureModel == "" {
86+
return model
87+
}
88+
return map[string]string{
89+
openai.GPT4TurboPreview: azureModel,
90+
}[model]
91+
}
92+
8393
func NewClient(opts ...Options) (*Client, error) {
8494
opt, err := complete(opts...)
8595
if err != nil {
@@ -89,6 +99,7 @@ func NewClient(opts ...Options) (*Client, error) {
8999
cfg := openai.DefaultConfig(opt.APIKey)
90100
if strings.Contains(string(opt.APIType), "AZURE") {
91101
cfg = openai.DefaultAzureConfig(key, url)
102+
cfg.AzureModelMapperFunc = AzureMapperFunction
92103
}
93104

94105
cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL)
@@ -236,15 +247,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
236247
}
237248

238249
request := openai.ChatCompletionRequest{
239-
Model: messageRequest.Model,
240-
Messages: msgs,
241-
MaxTokens: messageRequest.MaxTokens,
242-
Temperature: messageRequest.Temperature,
243-
Grammar: messageRequest.Grammar,
250+
Model: messageRequest.Model,
251+
Messages: msgs,
252+
MaxTokens: messageRequest.MaxTokens,
244253
}
245254

246-
if request.Temperature == nil {
247-
request.Temperature = new(float32)
255+
if messageRequest.Temperature == nil {
256+
// this is a hack because the field is marked as omitempty, so we need it to be set to a non-zero value but arbitrarily small
257+
request.Temperature = 1e-08
258+
} else {
259+
request.Temperature = *messageRequest.Temperature
248260
}
249261

250262
if messageRequest.JSONResponse {
@@ -260,7 +272,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
260272
}
261273
request.Tools = append(request.Tools, openai.Tool{
262274
Type: openai.ToolTypeFunction,
263-
Function: openai.FunctionDefinition{
275+
Function: &openai.FunctionDefinition{
264276
Name: tool.Function.Name,
265277
Description: tool.Function.Description,
266278
Parameters: params,

0 commit comments

Comments
 (0)