Skip to content

Commit

Permalink
Merge pull request #502 from CSUSTers/dev
Browse files Browse the repository at this point in the history
fix: bug cause by concurrently modifying bot context
  • Loading branch information
hugefiver authored Dec 17, 2024
2 parents d6e57e6 + 47006a9 commit 9951765
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 146 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,3 @@ jobs:
version: v1.62
args: --issues-exit-code=1
only-new-issues: false
skip-pkg-cache: true
skip-build-cache: true
83 changes: 66 additions & 17 deletions chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,33 @@ type chatContext struct {
msg *Message
}

// ChatInfo is input of ChatWith
// nolint: revive
type ChatInfo struct {
Setting

Text string
}

// Setting is settings of ChatWith
type Setting struct {
Stream bool
Reply bool

Placeholder string

Model string
Prompt string
}

// GetPlaceholder get message placeholder
func (s *Setting) GetPlaceholder() string {
if s.Placeholder == "" {
return "正在思考..."
}
return s.Placeholder
}

// InitChat init chat service
func InitChat() {
if config.BotConfig.ChatConfig.Key != "" {
Expand Down Expand Up @@ -61,38 +88,54 @@ func InitChat() {

// GPTChat is handler for chat with GPT
func GPTChat(ctx Context) error {
return chat(ctx, false)
return chatHandler(ctx, false)
}

// GPTChatWithStream is handler for chat with GPT, and use stream api
func GPTChatWithStream(ctx Context) error {
return chat(ctx, true)
return chatHandler(ctx, true)
}

func chat(ctx Context, stream bool) error {
if client == nil {
return nil
}

_, arg, err := entities.CommandTakeArgs(ctx.Message(), 0)
func chatHandler(ctx Context, stream bool) error {
_, text, err := entities.CommandTakeArgs(ctx.Message(), 0)
if err != nil {
log.Error("[ChatGPT] Can't take args", zap.Error(err))
return ctx.Reply("嗦啥呢?")
}

if len(arg) == 0 {
if len(text) == 0 {
return ctx.Reply("您好,有什么问题可以为您解答吗?")
}
if len(arg) > config.BotConfig.ChatConfig.PromptLimit {
if len(text) > config.BotConfig.ChatConfig.PromptLimit {
return ctx.Reply("TLDR")
}
return ChatWith(ctx, &ChatInfo{
Text: text,
Setting: Setting{
Stream: stream,
Reply: true,
},
})
}

// ChatWith chat with GPT
// nolint: revive
func ChatWith(ctx Context, info *ChatInfo) error {
if client == nil {
return nil
}

req, err := generateRequest(ctx, arg, stream)
req, err := generateRequest(ctx, info)
if err != nil {
return err
}

msg, err := util.SendReplyWithError(ctx.Chat(), "正在思考...", ctx.Message())
var msg *Message
if info.Reply {
msg, err = util.SendReplyWithError(ctx.Chat(), info.GetPlaceholder(), ctx.Message())
} else {
msg, err = util.SendMessageWithError(ctx.Chat(), info.GetPlaceholder(), ctx.Message())
}
if err != nil {
return err
}
Expand All @@ -107,24 +150,30 @@ func chat(ctx Context, stream bool) error {
}
}

func generateRequest(ctx Context, arg string, stream bool) (*openai.ChatCompletionRequest, error) {
func generateRequest(ctx Context, info *ChatInfo) (*openai.ChatCompletionRequest, error) {
chatCfg := config.BotConfig.ChatConfig
req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
MaxTokens: chatCfg.MaxTokens,
Messages: []openai.ChatCompletionMessage{},
Stream: stream,
Stream: info.Stream,
Temperature: chatCfg.Temperature,
}

if chatCfg.Model != "" {
if info.Model != "" {
req.Model = info.Model
} else if chatCfg.Model != "" {
req.Model = chatCfg.Model
}

if len(req.Messages) == 0 && chatCfg.SystemPrompt != "" {
prompt := info.Prompt
if prompt == "" {
prompt = chatCfg.SystemPrompt
}
req.Messages = append(req.Messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: chatCfg.SystemPrompt,
Content: prompt,
})
}

Expand All @@ -139,7 +188,7 @@ func generateRequest(ctx Context, arg string, stream bool) (*openai.ChatCompleti
}
}

req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: arg})
req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: info.Text})

return &req, nil
}
Expand Down
75 changes: 0 additions & 75 deletions chat/chat_customlized.go

This file was deleted.

16 changes: 9 additions & 7 deletions chat/gacha_reply.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package chat
import (
"csust-got/log"
"csust-got/util/gacha"
"strings"

"go.uber.org/zap"
"gopkg.in/telebot.v3"
)
Expand All @@ -18,7 +20,7 @@ func GachaReplyHandler(ctx telebot.Context) {
} else if len(msg.Caption) > 0 {
text = msg.Caption
}
if len(text) == 0 {
if len(text) == 0 || strings.HasPrefix(text, "/") {
return
}

Expand All @@ -27,17 +29,17 @@ func GachaReplyHandler(ctx telebot.Context) {
log.Error("[GaCha]: perform gacha failed", zap.Error(err))
return
}
ctx.Message().Text = "/chat " + text

switch result {
case 3:
return
case 4:
err = CustomModelChat(ctx)
if err != nil {
log.Error("[ChatGPT]: get a answer failed", zap.Error(err))
}
// TODO: `ChatWith` a different prompt
case 5:
err = GPTChat(ctx)
err = ChatWith(ctx, &ChatInfo{
Text: text,
Setting: Setting{Stream: false, Reply: true},
})
if err != nil {
log.Error("[ChatGPT]: get a answer failed", zap.Error(err))
}
Expand Down
8 changes: 6 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ func registerBaseHandler(bot *Bot) {

bot.Handle("/chat", chat.GPTChat, whiteMiddleware)
bot.Handle("/chats", chat.GPTChatWithStream, whiteMiddleware)
bot.Handle("/qiuchat", chat.CustomModelChat, whiteMiddleware)

// meilisearch handler
bot.Handle("/search", meili.SearchHandle)
Expand Down Expand Up @@ -441,7 +440,12 @@ func contentFilterMiddleware(next HandlerFunc) HandlerFunc {
return next(ctx)
}

go chat.GachaReplyHandler(ctx)
// DONE: gacha 会修改ctx.Message.Text,所以放到next之后,等dawu以后重构吧,详见 #501
// 2024-12-17 [dawu]: 已经重构
if m.Text != "" {
go chat.GachaReplyHandler(ctx)
}

return next(ctx)
}
}
Expand Down
20 changes: 15 additions & 5 deletions meili/bot_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"csust-got/log"
"csust-got/util"
"encoding/json"
"errors"
"fmt"
"strconv"

"github.com/meilisearch/meilisearch-go"
Expand Down Expand Up @@ -65,27 +67,30 @@ func executeSearch(ctx Context) string {
log.Debug("[GetChatMember]", zap.String("chatRecipient", ctx.Chat().Recipient()), zap.String("userRecipient", ctx.Sender().Recipient()))
// parse option
searchKeywordIdx := 0
if command.Argc() > 2 {
if command.Argc() >= 2 {
option := command.Arg(0)
if option == "-id" {
// when search by id, index 0 arg is "-id", 1 arg is id, pass rest to query
var err error
chatId, err = strconv.ParseInt(command.Arg(1), 10, 64)
if err != nil {
log.Error("[MeiliSearch]: Parse chat id failed", zap.String("Search args", command.ArgAllInOneFrom(0)), zap.Error(err))
return err.Error()
return "Invalid chat id"
}
searchKeywordIdx = 2
}
}
if searchKeywordIdx > 0 {
// check if user is a member of chat_id group
member, err := util.GetChatMember(ctx.Bot(), chatId, ctx.Sender().Recipient())
member, err := ctx.Bot().ChatMemberOf(ChatID(chatId), ctx.Sender())
if err != nil {
if errors.Is(err, ErrChatNotFound) {
return "Chat not found"
}
log.Error("[MeiliSearch]: Error in GetChatMember", zap.String("Search args", command.ArgAllInOneFrom(0)), zap.Error(err))
return "Error when getting chat member"
return "Not sure if you are a member of the specified group"
}
if member.Result.Status == "left" {
if member.Role == Left || member.Role == Kicked {
log.Error("[MeiliSearch]: Not a member of the specified group", zap.String("Search args", command.ArgAllInOneFrom(0)),
zap.Int64("chatId", chatId), zap.String("user", ctx.Sender().Recipient()))
return "Not a member of the specified group"
Expand All @@ -102,6 +107,11 @@ func executeSearch(ctx Context) string {
SearchRequest: searchRequest,
}
}

if query.Query == "" {
return fmt.Sprintf("search keyword is empty, use `%s <keyword>` to search", ctx.Message().Text)
}

result, err := SearchMeili(query)
if err != nil {
log.Error("[MeiliSearch]: search failed", zap.String("Search args", command.ArgAllInOneFrom(0)), zap.Error(err))
Expand Down
Loading

0 comments on commit 9951765

Please sign in to comment.