Skip to content

Commit 13b2e88

Browse files
committed
Implement mock embedding
1 parent f799379 commit 13b2e88

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

common/mock/openai.go

+23
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type OpenAIServer struct {
2929
ready chan struct{}
3030

3131
mockChatCompletion string
32+
mockEmbeddings []float32
3233
}
3334

3435
func NewOpenAIServer() *OpenAIServer {
@@ -41,6 +42,10 @@ func NewOpenAIServer() *OpenAIServer {
4142
Reads(openai.ChatCompletionRequest{}).
4243
Writes(openai.ChatCompletionResponse{}).
4344
To(s.chatCompletion))
45+
ws.Route(ws.POST("embeddings").
46+
Reads(openai.EmbeddingRequest{}).
47+
Writes(openai.EmbeddingResponse{}).
48+
To(s.embeddings))
4449
container := restful.NewContainer()
4550
container.Add(ws)
4651
s.httpServer = &http.Server{Handler: container}
@@ -79,6 +84,10 @@ func (s *OpenAIServer) ChatCompletion(mock string) {
7984
s.mockChatCompletion = mock
8085
}
8186

87+
func (s *OpenAIServer) Embeddings(embeddings []float32) {
88+
s.mockEmbeddings = embeddings
89+
}
90+
8291
func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Response) {
8392
var r openai.ChatCompletionRequest
8493
err := req.ReadEntity(&r)
@@ -94,3 +103,17 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon
94103
}},
95104
})
96105
}
106+
107+
func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) {
108+
var r openai.EmbeddingRequest
109+
err := req.ReadEntity(&r)
110+
if err != nil {
111+
_ = resp.WriteError(http.StatusBadRequest, err)
112+
return
113+
}
114+
_ = resp.WriteEntity(openai.EmbeddingResponse{
115+
Data: []openai.Embedding{{
116+
Embedding: s.mockEmbeddings,
117+
}},
118+
})
119+
}

common/mock/openai_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ func (suite *OpenAITestSuite) TestChatCompletion() {
6262
suite.Equal("World", resp.Choices[0].Message.Content)
6363
}
6464

65+
func (suite *OpenAITestSuite) TestEmbeddings() {
66+
suite.server.Embeddings([]float32{1, 2, 3})
67+
resp, err := suite.client.CreateEmbeddings(
68+
context.Background(),
69+
openai.EmbeddingRequest{
70+
Input: "Hello",
71+
Model: "mxbai-embed-large",
72+
},
73+
)
74+
suite.NoError(err)
75+
suite.Equal([]float32{1, 2, 3}, resp.Data[0].Embedding)
76+
}
77+
6578
func TestOpenAITestSuite(t *testing.T) {
6679
suite.Run(t, new(OpenAITestSuite))
6780
}

0 commit comments

Comments
 (0)