Skip to content

Commit b0a0818

Browse files
committed
add MasterAPITestSuite
1 parent 4ca4d58 commit b0a0818

File tree

4 files changed

+428
-407
lines changed

4 files changed

+428
-407
lines changed

common/mock/openai.go

+30-15
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package mock
1616

1717
import (
18+
"bytes"
19+
"encoding/json"
1820
"fmt"
1921
"github.com/emicklei/go-restful/v3"
2022
"github.com/sashabaranov/go-openai"
@@ -28,16 +30,15 @@ type OpenAIServer struct {
2830
authToken string
2931
ready chan struct{}
3032

31-
mockChatCompletion string
32-
mockEmbeddings []float32
33+
mockEmbeddings []float32
3334
}
3435

3536
func NewOpenAIServer() *OpenAIServer {
3637
s := &OpenAIServer{}
3738
ws := new(restful.WebService)
3839
ws.Path("/v1").
39-
Consumes(restful.MIME_XML, restful.MIME_JSON).
40-
Produces(restful.MIME_JSON, restful.MIME_XML)
40+
Consumes(restful.MIME_JSON).
41+
Produces(restful.MIME_JSON, "text/event-stream")
4142
ws.Route(ws.POST("chat/completions").
4243
Reads(openai.ChatCompletionRequest{}).
4344
Writes(openai.ChatCompletionResponse{}).
@@ -80,10 +81,6 @@ func (s *OpenAIServer) Close() error {
8081
return s.httpServer.Close()
8182
}
8283

83-
func (s *OpenAIServer) ChatCompletion(mock string) {
84-
s.mockChatCompletion = mock
85-
}
86-
8784
func (s *OpenAIServer) Embeddings(embeddings []float32) {
8885
s.mockEmbeddings = embeddings
8986
}
@@ -95,13 +92,31 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon
9592
_ = resp.WriteError(http.StatusBadRequest, err)
9693
return
9794
}
98-
_ = resp.WriteEntity(openai.ChatCompletionResponse{
99-
Choices: []openai.ChatCompletionChoice{{
100-
Message: openai.ChatCompletionMessage{
101-
Content: s.mockChatCompletion,
102-
},
103-
}},
104-
})
95+
if r.Stream {
96+
content := r.Messages[0].Content
97+
for i := 0; i < len(content); i += 8 {
98+
buf := bytes.NewBuffer(nil)
99+
buf.WriteString("data: ")
100+
encoder := json.NewEncoder(buf)
101+
_ = encoder.Encode(openai.ChatCompletionStreamResponse{
102+
Choices: []openai.ChatCompletionStreamChoice{{
103+
Delta: openai.ChatCompletionStreamChoiceDelta{
104+
Content: content[i:min(i+8, len(content))],
105+
},
106+
}},
107+
})
108+
buf.WriteString("\n")
109+
_, _ = resp.Write(buf.Bytes())
110+
}
111+
} else {
112+
_ = resp.WriteEntity(openai.ChatCompletionResponse{
113+
Choices: []openai.ChatCompletionChoice{{
114+
Message: openai.ChatCompletionMessage{
115+
Content: r.Messages[0].Content,
116+
},
117+
}},
118+
})
119+
}
105120
}
106121

107122
func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) {

common/mock/openai_test.go

+36-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ package mock
1616

1717
import (
1818
"context"
19+
"github.com/juju/errors"
1920
"github.com/sashabaranov/go-openai"
2021
"github.com/stretchr/testify/suite"
22+
"io"
23+
"strings"
2124
"testing"
2225
)
2326

@@ -45,7 +48,6 @@ func (suite *OpenAITestSuite) TearDownSuite() {
4548
}
4649

4750
func (suite *OpenAITestSuite) TestChatCompletion() {
48-
suite.server.ChatCompletion("World")
4951
resp, err := suite.client.CreateChatCompletion(
5052
context.Background(),
5153
openai.ChatCompletionRequest{
@@ -59,7 +61,39 @@ func (suite *OpenAITestSuite) TestChatCompletion() {
5961
},
6062
)
6163
suite.NoError(err)
62-
suite.Equal("World", resp.Choices[0].Message.Content)
64+
suite.Equal("Hello", resp.Choices[0].Message.Content)
65+
}
66+
67+
func (suite *OpenAITestSuite) TestChatCompletionStream() {
68+
content := "In my younger and more vulnerable years my father gave me some advice that I've been turning over in" +
69+
" my mind ever since. Whenever you feel like criticizing anyone, he told me, just remember that all the " +
70+
"people in this world haven't had the advantages that you've had."
71+
stream, err := suite.client.CreateChatCompletionStream(
72+
context.Background(),
73+
openai.ChatCompletionRequest{
74+
Model: "qwen2.5",
75+
Messages: []openai.ChatCompletionMessage{
76+
{
77+
Role: openai.ChatMessageRoleUser,
78+
Content: content,
79+
},
80+
},
81+
Stream: true,
82+
},
83+
)
84+
suite.NoError(err)
85+
defer stream.Close()
86+
var buffer strings.Builder
87+
for {
88+
var resp openai.ChatCompletionStreamResponse
89+
resp, err = stream.Recv()
90+
if errors.Is(err, io.EOF) {
91+
suite.Equal(content, buffer.String())
92+
return
93+
}
94+
suite.NoError(err)
95+
buffer.WriteString(resp.Choices[0].Delta.Content)
96+
}
6397
}
6498

6599
func (suite *OpenAITestSuite) TestEmbeddings() {

master/master.go

+9
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/emicklei/go-restful/v3"
3030
"github.com/jellydator/ttlcache/v3"
3131
"github.com/juju/errors"
32+
"github.com/sashabaranov/go-openai"
3233
"github.com/zhenghaoz/gorse/base"
3334
"github.com/zhenghaoz/gorse/base/encoding"
3435
"github.com/zhenghaoz/gorse/base/log"
@@ -71,6 +72,7 @@ type Master struct {
7172
jobsScheduler *task.JobsScheduler
7273
cacheFile string
7374
managedMode bool
75+
openAIClient *openai.Client
7476

7577
// cluster meta cache
7678
metaStore meta.Database
@@ -116,6 +118,7 @@ type Master struct {
116118
// NewMaster creates a master node.
117119
func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master {
118120
rand.Seed(time.Now().UnixNano())
121+
119122
// setup trace provider
120123
tp, err := cfg.Tracing.NewTracerProvider()
121124
if err != nil {
@@ -124,12 +127,18 @@ func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master {
124127
otel.SetTracerProvider(tp)
125128
otel.SetErrorHandler(log.GetErrorHandler())
126129
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
130+
131+
// setup OpenAI client
132+
clientConfig := openai.DefaultConfig(cfg.OpenAI.AuthToken)
133+
clientConfig.BaseURL = cfg.OpenAI.BaseURL
134+
127135
m := &Master{
128136
// create task monitor
129137
cacheFile: cacheFile,
130138
managedMode: managedMode,
131139
jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs),
132140
tracer: progress.NewTracer("master"),
141+
openAIClient: openai.NewClientWithConfig(clientConfig),
133142
// default ranking model
134143
rankingModelName: "bpr",
135144
rankingModelSearcher: ranking.NewModelSearcher(

0 commit comments

Comments
 (0)