Skip to content

Commit

Permalink
Implement mock OpenAI server (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 28, 2025
1 parent 337fbe9 commit b9086da
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 0 deletions.
119 changes: 119 additions & 0 deletions common/mock/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mock

import (
"fmt"
"github.com/emicklei/go-restful/v3"
"github.com/sashabaranov/go-openai"
"net"
"net/http"
)

type OpenAIServer struct {
listener net.Listener
httpServer *http.Server
authToken string
ready chan struct{}

mockChatCompletion string
mockEmbeddings []float32
}

func NewOpenAIServer() *OpenAIServer {
s := &OpenAIServer{}
ws := new(restful.WebService)
ws.Path("/v1").
Consumes(restful.MIME_XML, restful.MIME_JSON).
Produces(restful.MIME_JSON, restful.MIME_XML)
ws.Route(ws.POST("chat/completions").
Reads(openai.ChatCompletionRequest{}).
Writes(openai.ChatCompletionResponse{}).
To(s.chatCompletion))
ws.Route(ws.POST("embeddings").
Reads(openai.EmbeddingRequest{}).
Writes(openai.EmbeddingResponse{}).
To(s.embeddings))
container := restful.NewContainer()
container.Add(ws)
s.httpServer = &http.Server{Handler: container}
s.authToken = "ollama"
s.ready = make(chan struct{})
return s
}

func (s *OpenAIServer) Start() error {
var err error
s.listener, err = net.Listen("tcp", "")
if err != nil {
return err
}
close(s.ready)
return s.httpServer.Serve(s.listener)
}

func (s *OpenAIServer) BaseURL() string {
return fmt.Sprintf("http://%s/v1", s.listener.Addr().String())
}

func (s *OpenAIServer) AuthToken() string {
return s.authToken
}

func (s *OpenAIServer) Ready() {
<-s.ready
}

func (s *OpenAIServer) Close() error {
return s.httpServer.Close()
}

func (s *OpenAIServer) ChatCompletion(mock string) {
s.mockChatCompletion = mock
}

func (s *OpenAIServer) Embeddings(embeddings []float32) {
s.mockEmbeddings = embeddings
}

func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Response) {
var r openai.ChatCompletionRequest
err := req.ReadEntity(&r)
if err != nil {
_ = resp.WriteError(http.StatusBadRequest, err)
return
}
_ = resp.WriteEntity(openai.ChatCompletionResponse{
Choices: []openai.ChatCompletionChoice{{
Message: openai.ChatCompletionMessage{
Content: s.mockChatCompletion,
},
}},
})
}

func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) {
var r openai.EmbeddingRequest
err := req.ReadEntity(&r)
if err != nil {
_ = resp.WriteError(http.StatusBadRequest, err)
return
}
_ = resp.WriteEntity(openai.EmbeddingResponse{
Data: []openai.Embedding{{
Embedding: s.mockEmbeddings,
}},
})
}
80 changes: 80 additions & 0 deletions common/mock/openai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mock

import (
"context"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/suite"
"testing"
)

type OpenAITestSuite struct {
suite.Suite
server *OpenAIServer
client *openai.Client
}

func (suite *OpenAITestSuite) SetupSuite() {
// Start mock server
suite.server = NewOpenAIServer()
go func() {
_ = suite.server.Start()
}()
suite.server.Ready()
// Create client
clientConfig := openai.DefaultConfig(suite.server.AuthToken())
clientConfig.BaseURL = suite.server.BaseURL()
suite.client = openai.NewClientWithConfig(clientConfig)
}

func (suite *OpenAITestSuite) TearDownSuite() {
suite.NoError(suite.server.Close())
}

func (suite *OpenAITestSuite) TestChatCompletion() {
suite.server.ChatCompletion("World")
resp, err := suite.client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: "qwen2.5",
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello",
},
},
},
)
suite.NoError(err)
suite.Equal("World", resp.Choices[0].Message.Content)
}

func (suite *OpenAITestSuite) TestEmbeddings() {
suite.server.Embeddings([]float32{1, 2, 3})
resp, err := suite.client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Input: "Hello",
Model: "mxbai-embed-large",
},
)
suite.NoError(err)
suite.Equal([]float32{1, 2, 3}, resp.Data[0].Embedding)
}

func TestOpenAITestSuite(t *testing.T) {
suite.Run(t, new(OpenAITestSuite))
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ require (
github.com/redis/go-redis/extra/redisotel/v9 v9.5.3
github.com/redis/go-redis/v9 v9.7.0
github.com/samber/lo v1.38.1
github.com/sashabaranov/go-openai v1.36.1
github.com/schollz/progressbar/v3 v3.17.1
github.com/sclevine/yj v0.0.0-20210612025309-737bdf40a5d1
github.com/spf13/cobra v1.8.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,8 @@ github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6g
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/sashabaranov/go-openai v1.36.1 h1:EVfRXwIlW2rUzpx6vR+aeIKCK/xylSrVYAx1TMTSX3g=
github.com/sashabaranov/go-openai v1.36.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/schollz/progressbar/v3 v3.17.1 h1:bI1MTaoQO+v5kzklBjYNRQLoVpe0zbyRZNK6DFkVC5U=
github.com/schollz/progressbar/v3 v3.17.1/go.mod h1:RzqpnsPQNjUyIgdglUjRLgD7sVnxN1wpmBMV+UiEbL4=
Expand Down

0 comments on commit b9086da

Please sign in to comment.