Skip to content

Commit 7ac2070

Browse files
committed
feat: 为渠道添加上下文限制功能
1 parent c826d06 commit 7ac2070

File tree

7 files changed

+38
-0
lines changed

7 files changed

+38
-0
lines changed

constant/context_key.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ const (
3131
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
3232
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
3333
ContextKeyChannelKey ContextKey = "channel_key"
34+
ContextKeyChannelTokenLimit ContextKey = "channel_token_limit"
3435

3536
/* user related keys */
3637
ContextKeyUserId ContextKey = "id"

middleware/distributor.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
267267
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
268268
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
269269
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
270+
common.SetContextKey(c, constant.ContextKeyChannelTokenLimit, channel.GetTokenLimit())
270271

271272
key, index, newAPIError := channel.GetNextEnabledKey()
272273
if newAPIError != nil {

model/channel.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type Channel struct {
4444
Tag *string `json:"tag" gorm:"index"`
4545
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
4646
ParamOverride *string `json:"param_override" gorm:"type:text"`
47+
TokenLimit *int `json:"token_limit" gorm:"default:0"`
4748
// add after v0.8.5
4849
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
4950
}
@@ -397,6 +398,14 @@ func (channel *Channel) GetStatusCodeMapping() string {
397398
return *channel.StatusCodeMapping
398399
}
399400

401+
func (channel *Channel) GetTokenLimit() int {
402+
if channel.TokenLimit == nil || *channel.TokenLimit <= 0 {
403+
return 0
404+
}
405+
return *channel.TokenLimit
406+
}
407+
408+
400409
func (channel *Channel) Insert() error {
401410
var err error
402411
err = DB.Create(channel).Error

relay/common/relay_info.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ type ResponsesUsageInfo struct {
6262
type RelayInfo struct {
6363
ChannelType int
6464
ChannelId int
65+
ChannelTokenLimit int
6566
TokenId int
6667
TokenKey string
6768
UserId int
@@ -215,6 +216,7 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
215216
func GenRelayInfo(c *gin.Context) *RelayInfo {
216217
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
217218
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
219+
channelTokenLimit := common.GetContextKeyInt(c, constant.ContextKeyChannelTokenLimit)
218220
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
219221

220222
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
@@ -235,6 +237,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
235237
RequestURLPath: c.Request.URL.String(),
236238
ChannelType: channelType,
237239
ChannelId: channelId,
240+
ChannelTokenLimit: channelTokenLimit,
238241
TokenId: tokenId,
239242
TokenKey: tokenKey,
240243
UserId: userId,

relay/relay-text.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
127127
c.Set("prompt_tokens", promptTokens)
128128
}
129129

130+
err = checkPromptTokensInBotChannel(promptTokens, relayInfo)
131+
if err != nil {
132+
return types.NewError(err, types.ErrorCodePromptTokensTooLarge)
133+
}
134+
130135
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
131136
if err != nil {
132137
return types.NewError(err, types.ErrorCodeModelPriceError)
@@ -261,6 +266,13 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
261266
return promptTokens, err
262267
}
263268

269+
func checkPromptTokensInBotChannel(promptTokens int, info *relaycommon.RelayInfo) error {
270+
if info.ChannelTokenLimit > 0 && promptTokens > info.ChannelTokenLimit {
271+
return fmt.Errorf("prompt tokens (%d) is greater than channel token limit (%d)", promptTokens, info.ChannelTokenLimit)
272+
}
273+
return nil
274+
}
275+
264276
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
265277
var err error
266278
var words []string

types/error.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ const (
3838

3939
// new api error
4040
ErrorCodeCountTokenFailed ErrorCode = "count_token_failed"
41+
ErrorCodePromptTokensTooLarge ErrorCode = "prompt_tokens_too_large"
4142
ErrorCodeModelPriceError ErrorCode = "model_price_error"
4243
ErrorCodeInvalidApiType ErrorCode = "invalid_api_type"
4344
ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed"

web/src/pages/Channel/EditChannel.js

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ const EditChannel = (props) => {
102102
groups: ['default'],
103103
priority: 0,
104104
weight: 0,
105+
token_limit: 0,
105106
tag: '',
106107
multi_key_mode: 'random',
107108
};
@@ -1371,6 +1372,16 @@ const EditChannel = (props) => {
13711372
style={{ width: '100%' }}
13721373
/>
13731374
</Col>
1375+
<Col span={12}>
1376+
<Form.InputNumber
1377+
field='token_limit'
1378+
label={t('最大上下文')}
1379+
placeholder={t('最大上下文')}
1380+
min={0}
1381+
onNumberChange={(value) => handleInputChange('token_limit', value)}
1382+
style={{ width: '100%' }}
1383+
/>
1384+
</Col>
13741385
</Row>
13751386

13761387
<Form.Switch

0 commit comments

Comments
 (0)