Skip to content

Commit

Permalink
Merge pull request #170 from CSUSTers/dev
Browse files Browse the repository at this point in the history
Feat: chat command
  • Loading branch information
Anthony-Hoo authored Mar 5, 2023
2 parents 822d9a1 + 0d1d2ce commit a367b32
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v2
uses: docker/build-push-action@v4
with:
context: .
push: true
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ hugencoder - <text> huge编码
hugedecoder - <text> huge解码
getvoice - 角色=<character> 性别=<sex> 主题=<topic> 类型=<type> <text> 通过前述五个参数查询(可选填),获取一段来自游戏《原神》的角色语音(Chinese Olny),数据来源于游戏解包
getvoice_old - getvoice的旧版入口,没有查询功能,数据来源于mys爬虫
chat - <text> 聊会天呗
```

## attachment
Expand Down
206 changes: 206 additions & 0 deletions chat/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package chat

import (
"context"
"csust-got/config"
"csust-got/entities"
"csust-got/log"
"csust-got/orm"
"csust-got/util"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"

gogpt "github.com/sashabaranov/go-gpt3"
"go.uber.org/zap"
. "gopkg.in/telebot.v3"
)

var (
client *gogpt.Client
chatChan = make(chan *chatContext, 16)
)

type chatContext struct {
Context
req *gogpt.ChatCompletionRequest
msg *Message
}

// InitChat init chat service
func InitChat() {
if config.BotConfig.ChatConfig.Key != "" {
client = gogpt.NewClient(config.BotConfig.ChatConfig.Key)
go chatService()
}
}

// GPTChat is handler for chat with GPT
func GPTChat(ctx Context) error {
if client == nil {
return nil
}

_, arg, err := entities.CommandTakeArgs(ctx.Message(), 0)
if err != nil {
log.Error("[ChatGPT] Can't take args", zap.Error(err))
return ctx.Reply("嗦啥呢?")
}
if len(arg) == 0 {
return ctx.Reply("您好,有什么问题可以为您解答吗?")
}
if len(arg) > config.BotConfig.ChatConfig.PromptLimit {
return ctx.Reply("TLDR")
}

req, err := generateRequest(ctx, arg)
if err != nil {
return err
}

msg, err := util.SendReplyWithError(ctx.Chat(), "正在思考...", ctx.Message())
if err != nil {
return err
}

payload := &chatContext{Context: ctx, req: req, msg: msg}

select {
case chatChan <- payload:
return nil
default:
return ctx.Reply("要处理的对话太多了,要不您稍后再试试?")
}
}

func generateRequest(ctx Context, arg string) (*gogpt.ChatCompletionRequest, error) {
req := gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo,
MaxTokens: config.BotConfig.ChatConfig.MaxTokens,
Messages: []gogpt.ChatCompletionMessage{},
Stream: true,
Temperature: config.BotConfig.ChatConfig.Temperature,
}

if len(req.Messages) == 0 && config.BotConfig.ChatConfig.SystemPrompt != "" {
req.Messages = append(req.Messages, gogpt.ChatCompletionMessage{Role: "system", Content: config.BotConfig.ChatConfig.SystemPrompt})
}

keepContext := config.BotConfig.ChatConfig.KeepContext
if keepContext > 0 && ctx.Message().ReplyTo != nil {
chatContext, err := orm.GetChatContext(ctx.Chat().ID, ctx.Message().ReplyTo.ID)
if err == nil {
if len(chatContext) > 2*keepContext {
chatContext = chatContext[len(chatContext)-2*keepContext:]
}
req.Messages = append(req.Messages, chatContext...)
}
}

req.Messages = append(req.Messages, gogpt.ChatCompletionMessage{Role: "user", Content: arg})

return &req, nil
}

func chatService() {
for ctx := range chatChan {
go func(ctx *chatContext) {
start := time.Now()

// resp, err := client.CreateChatCompletion(context.Background(), *ctx.req)
// if err != nil {
// log.Error("[ChatGPT] Can't create completion", zap.Error(err))
// return
// }
// fmt.Printf("%+v", resp)

// content := resp.Choices[0].Message.Content

var replyMsg *Message

stream, err := client.CreateChatCompletionStream(context.Background(), *ctx.req)
if err != nil {
_, err = util.EditMessageWithError(ctx.msg,
"An error occurred. If this issue persists please contact us through our help center at help.openai.com.")
if err != nil {
log.Error("[ChatGPT] Can't edit message", zap.Error(err))
}
return
}
defer stream.Close()

content := ""
contentLock := sync.Mutex{}
done := make(chan struct{})
go func() {
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
ctx.req.Messages = append(ctx.req.Messages, gogpt.ChatCompletionMessage{Role: "assistant", Content: content})
done <- struct{}{}
break
}

if err != nil {
contentLock.Lock()
content += "\n\n...寄了"
contentLock.Unlock()
log.Error("[ChatGPT] Stream error", zap.Error(err))
break
}

contentLock.Lock()
content += response.Choices[0].Delta.Content
contentLock.Unlock()
}
}()

ticker := time.NewTicker(2 * time.Second) // 编辑过快会被tg限流
defer ticker.Stop()
lastContent := "" // 记录上次编辑的内容,内容相同则不再编辑,避免tg的api返回400
out:
for range ticker.C {
contentLock.Lock()
contentCopy := content
contentLock.Unlock()
if len(strings.TrimSpace(contentCopy)) > 0 && strings.TrimSpace(contentCopy) != strings.TrimSpace(lastContent) {
replyMsg, err = util.EditMessageWithError(ctx.msg, contentCopy)
if err != nil {
log.Error("[ChatGPT] Can't edit message", zap.Error(err))
} else {
lastContent = contentCopy
}
}
select {
case <-done:
break out
default:
}
}

contentLock.Lock()
if strings.TrimSpace(content) == "" {
content += "\n...嗦不粗话"
}
if config.BotConfig.DebugMode {
// content += fmt.Sprintf("\n\nusage: %d + %d = %d\n", resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
content += fmt.Sprintf("\n\ntime cost: %v\n", time.Since(start))
replyMsg, err = util.EditMessageWithError(ctx.msg, content)
if err != nil {
log.Error("[ChatGPT] Can't edit message", zap.Error(err))
}
}
contentLock.Unlock()

if replyMsg != nil {
err = orm.SetChatContext(ctx.Context.Chat().ID, replyMsg.ID, ctx.req.Messages)
if err != nil {
log.Error("[ChatGPT] Can't set chat context", zap.Error(err))
}
}
}(ctx)
}
}
10 changes: 9 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,12 @@ prometheus:
# 原神语音api
genshin_voice:
api_server: "https://api.csu.st"
err_audio_addr: "https://api.csu.st/file/VO_inGame/VO_NPC/NPC_DQ/vo_npc_dq_f_katheryne_01.ogg"
err_audio_addr: "https://api.csu.st/file/VO_inGame/VO_NPC/NPC_DQ/vo_npc_dq_f_katheryne_01.ogg"

chatgpt:
key: ""
max_tokens: 1000
temperature: 1
prompt_limit: 500
system_prompt: ""
keep_context: 0
36 changes: 36 additions & 0 deletions config/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package config

import "github.com/spf13/viper"

type chatConfig struct {
Key string
MaxTokens int
Temperature float32
PromptLimit int
SystemPrompt string
KeepContext int
}

func (c *chatConfig) readConfig() {
c.Key = viper.GetString("chatgpt.key")
c.MaxTokens = viper.GetInt("chatgpt.max_tokens")
c.Temperature = float32(viper.GetFloat64("chatgpt.temperature"))
c.PromptLimit = viper.GetInt("chatgpt.prompt_limit")
c.SystemPrompt = viper.GetString("chatgpt.system_prompt")
c.KeepContext = viper.GetInt("chatgpt.keep_context")
}

func (c *chatConfig) checkConfig() {
if c.MaxTokens <= 0 {
c.MaxTokens = 10
}
if c.Temperature < 0 || c.Temperature > 1 {
c.Temperature = 0.7
}
if c.PromptLimit <= 0 {
c.PromptLimit = 10
}
if c.KeepContext < 0 {
c.KeepContext = 0
}
}
4 changes: 4 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func NewBotConfig() *Config {
config.WhiteListConfig.SetName("white_list")
config.BlockListConfig.SetName("black_list")
config.GenShinConfig = new(genShinConfig)
config.ChatConfig = new(chatConfig)
return config
}

Expand All @@ -67,6 +68,7 @@ type Config struct {
WhiteListConfig *specialListConfig
PromConfig *promConfig
GenShinConfig *genShinConfig
ChatConfig *chatConfig
}

// GetBot returns Bot.
Expand Down Expand Up @@ -107,6 +109,7 @@ func readConfig() {
BotConfig.WhiteListConfig.readConfig()
BotConfig.BlockListConfig.readConfig()
BotConfig.PromConfig.readConfig()
BotConfig.ChatConfig.readConfig()

// genshin voice
BotConfig.GenShinConfig.readConfig()
Expand All @@ -132,4 +135,5 @@ func checkConfig() {
BotConfig.WhiteListConfig.checkConfig()
BotConfig.PromConfig.checkConfig()
BotConfig.GenShinConfig.checkConfig()
BotConfig.ChatConfig.checkConfig()
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ require (
github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-19 v0.2.1 // indirect
github.com/quic-go/qtls-go1-20 v0.1.1 // indirect
github.com/sashabaranov/go-gpt3 v1.3.3 // indirect
github.com/spf13/afero v1.9.3 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBO
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/sagikazarmark/crypt v0.6.0/go.mod h1:U8+INwJo3nBv1m6A/8OBXAq7Jnpspk5AxSgDyEQcea8=
github.com/sashabaranov/go-gpt3 v1.3.3 h1:S8Zd4YybnBaNMK+w+XGGWgsjQY1R+6QE2n9SLzVna9k=
github.com/sashabaranov/go-gpt3 v1.3.3/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
Expand Down
22 changes: 18 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"csust-got/chat"
"csust-got/sd"
"net/http"
"net/url"
Expand Down Expand Up @@ -45,6 +46,8 @@ func main() {

go sd.Process()

go chat.InitChat()

base.Init()

bot.Start()
Expand Down Expand Up @@ -124,6 +127,8 @@ func registerBaseHandler(bot *Bot) {

bot.Handle("/getvoice_old", base.GetVoice)
bot.Handle("/getvoice", base.GetVoiceV2)

bot.Handle("/chat", chat.GPTChat, whiteMiddleware)
}

func registerRestrictHandler(bot *Bot) {
Expand Down Expand Up @@ -206,14 +211,10 @@ func fakeBanMiddleware(next HandlerFunc) HandlerFunc {
}

func rateMiddleware(next HandlerFunc) HandlerFunc {
whiteListConfig := config.BotConfig.WhiteListConfig
return func(ctx Context) error {
if !isChatMessageHasSender(ctx) || ctx.Chat().Type == ChatPrivate {
return next(ctx)
}
if !whiteListConfig.Enabled || !whiteListConfig.Check(ctx.Chat().ID) {
return next(ctx)
}
if !restrict.CheckLimit(ctx.Message()) {
log.Info("message deleted by rate limit", zap.String("chat", ctx.Chat().Title),
zap.String("user", ctx.Sender().Username))
Expand All @@ -223,6 +224,19 @@ func rateMiddleware(next HandlerFunc) HandlerFunc {
}
}

func whiteMiddleware(next HandlerFunc) HandlerFunc {
return func(ctx Context) error {
if !config.BotConfig.WhiteListConfig.Enabled {
return next(ctx)
}
if ctx.Chat() != nil && !config.BotConfig.WhiteListConfig.Check(ctx.Chat().ID) {
log.Info("chat ignore by white list", zap.String("chat", ctx.Chat().Title))
return nil
}
return next(ctx)
}
}

func noStickerMiddleware(next HandlerFunc) HandlerFunc {
return func(ctx Context) error {
if !isChatMessageHasSender(ctx) || ctx.Message().Sticker == nil {
Expand Down
Loading

0 comments on commit a367b32

Please sign in to comment.