Skip to content

Commit 0a5e1cc

Browse files
committed
proxy chat
1 parent 072eed7 commit 0a5e1cc

File tree

2 files changed

+32
-80
lines changed

2 files changed

+32
-80
lines changed

master/rest.go

+26-64
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package master
1616

1717
import (
18-
"bytes"
1918
"context"
2019
"encoding/binary"
2120
"encoding/json"
@@ -36,8 +35,6 @@ import (
3635
"github.com/gorilla/securecookie"
3736
_ "github.com/gorse-io/dashboard"
3837
"github.com/juju/errors"
39-
"github.com/nikolalohinski/gonja/v2"
40-
"github.com/nikolalohinski/gonja/v2/exec"
4138
"github.com/rakyll/statik/fs"
4239
"github.com/samber/lo"
4340
"github.com/sashabaranov/go-openai"
@@ -1640,82 +1637,47 @@ func (m *Master) chat(response http.ResponseWriter, request *http.Request) {
16401637
writeError(response, http.StatusUnauthorized, "unauthorized")
16411638
return
16421639
}
1643-
1644-
var (
1645-
itemId = request.URL.Query().Get("item_id")
1646-
userId = request.URL.Query().Get("user_id")
1647-
)
1648-
1649-
// parse prompt template
1650-
b, err := io.ReadAll(request.Body)
1640+
content, err := io.ReadAll(request.Body)
16511641
if err != nil {
16521642
writeError(response, http.StatusInternalServerError, err.Error())
16531643
return
16541644
}
1655-
prompt, err := gonja.FromString(string(b))
1645+
stream, err := m.openAIClient.CreateChatCompletionStream(
1646+
request.Context(),
1647+
openai.ChatCompletionRequest{
1648+
Model: m.Config.OpenAI.ChatCompletionModel,
1649+
Messages: []openai.ChatCompletionMessage{
1650+
{
1651+
Role: openai.ChatMessageRoleUser,
1652+
Content: string(content),
1653+
},
1654+
},
1655+
Stream: true,
1656+
},
1657+
)
16561658
if err != nil {
1657-
writeError(response, http.StatusBadRequest, err.Error())
1659+
writeError(response, http.StatusInternalServerError, err.Error())
16581660
return
16591661
}
1660-
1661-
if itemId != "" {
1662-
// get item
1663-
item, err := m.DataClient.GetItem(request.Context(), itemId)
1664-
if err != nil {
1665-
writeError(response, http.StatusInternalServerError, err.Error())
1662+
// read response
1663+
defer stream.Close()
1664+
for {
1665+
var resp openai.ChatCompletionStreamResponse
1666+
resp, err = stream.Recv()
1667+
if errors.Is(err, io.EOF) {
16661668
return
16671669
}
1668-
// render prompt
1669-
var buf bytes.Buffer
1670-
err = prompt.Execute(&buf, exec.NewContext(map[string]any{
1671-
"item": item,
1672-
}))
16731670
if err != nil {
16741671
writeError(response, http.StatusInternalServerError, err.Error())
16751672
return
16761673
}
1677-
// create chat completion stream
1678-
stream, err := m.openAIClient.CreateChatCompletionStream(
1679-
request.Context(),
1680-
openai.ChatCompletionRequest{
1681-
Model: m.Config.OpenAI.ChatCompletionModel,
1682-
Messages: []openai.ChatCompletionMessage{
1683-
{
1684-
Role: openai.ChatMessageRoleUser,
1685-
Content: buf.String(),
1686-
},
1687-
},
1688-
Stream: true,
1689-
},
1690-
)
1691-
if err != nil {
1692-
writeError(response, http.StatusInternalServerError, err.Error())
1674+
if _, err = response.Write([]byte(resp.Choices[0].Delta.Content)); err != nil {
1675+
log.Logger().Error("failed to write response", zap.Error(err))
16931676
return
16941677
}
1695-
// read response
1696-
defer stream.Close()
1697-
for {
1698-
var resp openai.ChatCompletionStreamResponse
1699-
resp, err = stream.Recv()
1700-
if errors.Is(err, io.EOF) {
1701-
return
1702-
}
1703-
if err != nil {
1704-
writeError(response, http.StatusInternalServerError, err.Error())
1705-
return
1706-
}
1707-
if _, err = response.Write([]byte(resp.Choices[0].Delta.Content)); err != nil {
1708-
log.Logger().Error("failed to write response", zap.Error(err))
1709-
return
1710-
}
1711-
// flush response
1712-
if f, ok := response.(http.Flusher); ok {
1713-
f.Flush()
1714-
}
1678+
// flush response
1679+
if f, ok := response.(http.Flusher); ok {
1680+
f.Flush()
17151681
}
1716-
} else if userId != "" {
1717-
writeError(response, http.StatusNotImplemented, "chat with user is not implemented")
1718-
} else {
1719-
writeError(response, http.StatusBadRequest, "missing item_id or user_id")
17201682
}
17211683
}

master/rest_test.go

+6-16
Original file line numberDiff line numberDiff line change
@@ -958,27 +958,17 @@ func (suite *MasterAPITestSuite) TestExportAndImport() {
958958
}
959959
}
960960

961-
func (suite *MasterAPITestSuite) TestChatItem() {
962-
// insert item
963-
ctx := context.Background()
964-
err := suite.DataClient.BatchInsertItems(ctx, []data.Item{{
965-
ItemId: "0",
966-
Labels: map[string]any{"author": "F. Scott Fitzgerald"},
967-
Comment: "The Great Gatsby",
968-
}})
969-
suite.NoError(err)
970-
971-
// chat item
972-
buf := strings.NewReader("{{ item.Labels.author }}'s {{ item.Comment }}")
961+
func (suite *MasterAPITestSuite) TestChat() {
962+
content := "In my younger and more vulnerable years my father gave me some advice that I've been turning over in" +
963+
" my mind ever since. \"Whenever you feel like criticizing any one,\" he told me, \" just remember that all " +
964+
"the people in this world haven't had the advantages that you've had.\""
965+
buf := strings.NewReader(content)
973966
req := httptest.NewRequest("POST", "https://example.com/", buf)
974-
q := req.URL.Query()
975-
q.Add("item_id", "0")
976-
req.URL.RawQuery = q.Encode()
977967
req.Header.Set("Cookie", suite.cookie)
978968
w := httptest.NewRecorder()
979969
suite.chat(w, req)
980970
suite.Equal(http.StatusOK, w.Code, w.Body.String())
981-
suite.Equal("F. Scott Fitzgerald's The Great Gatsby", w.Body.String())
971+
suite.Equal(content, w.Body.String())
982972
}
983973

984974
func TestMasterAPI(t *testing.T) {

0 commit comments

Comments
 (0)