Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions constant/context_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
ContextKeyChannelKey ContextKey = "channel_key"
ContextKeyChannelRateLimit ContextKey = "channel_rate_limit"

/* user related keys */
ContextKeyUserId ContextKey = "id"
Expand Down
1 change: 1 addition & 0 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
common.SetContextKey(c, constant.ContextKeyChannelRateLimit, channel.GetRateLimit())

key, index, newAPIError := channel.GetNextEnabledKey()
if newAPIError != nil {
Expand Down
144 changes: 127 additions & 17 deletions model/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"one-api/common"
"strings"
"sync"
"time"

"github.com/samber/lo"
"gorm.io/gorm"
Expand Down Expand Up @@ -102,6 +103,81 @@ func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
return channelQuery, nil
}

var channelRateLimitStatus sync.Map // 存储每个 Channel 的频率限制状态
var rateLimitMutex sync.Mutex

type ChannelRateLimit struct {
Count int64 // 使用次数
ResetTime time.Time // 上次重置时间
}

type ChannelModelKey struct {
ChannelID int
}

func isRateLimited(channel Channel, channelId int) bool {
if (channel.RateLimit != nil && *channel.RateLimit > 0) {
if _, ok := checkRateLimit(&channel, channelId); !ok {
return true
}
updateRateLimitStatus(channelId)
}
return false
}
Comment on lines +118 to +126
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Double-count bug: remove updateRateLimitStatus after successful check

checkRateLimit already increments Count. Calling updateRateLimitStatus increments again, reducing the effective limit.

 func isRateLimited(channel Channel, channelId int) bool {
 	if (channel.RateLimit != nil && *channel.RateLimit > 0) {
 		if _, ok := checkRateLimit(&channel, channelId); !ok {
 			return true
 		}
-		updateRateLimitStatus(channelId)
 	}
 	return false
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func isRateLimited(channel Channel, channelId int) bool {
if (channel.RateLimit != nil && *channel.RateLimit > 0) {
if _, ok := checkRateLimit(&channel, channelId); !ok {
return true
}
updateRateLimitStatus(channelId)
}
return false
}
func isRateLimited(channel Channel, channelId int) bool {
if (channel.RateLimit != nil && *channel.RateLimit > 0) {
if _, ok := checkRateLimit(&channel, channelId); !ok {
return true
}
}
return false
}
🤖 Prompt for AI Agents
In model/ability.go around lines 118 to 126, the function is double-incrementing
the rate counter because checkRateLimit already increments Count; remove the
post-check call to updateRateLimitStatus(channelId) so that on a successful
check the count is not incremented again, leaving updateRateLimitStatus only
where increments are actually needed (e.g., on explicit failures or outside this
successful path).



func checkRateLimit(channel *Channel, channelId int) (*ChannelRateLimit, bool) {
now := time.Now()
key := ChannelModelKey{ChannelID: channelId}

rateLimitMutex.Lock()
defer rateLimitMutex.Unlock()

value, exists := channelRateLimitStatus.Load(key)
if !exists {
value = &ChannelRateLimit{
Count: 1,
ResetTime: now.Add(time.Minute),
}
channelRateLimitStatus.Store(key, value)
return value.(*ChannelRateLimit), true
}
rateLimit := value.(*ChannelRateLimit)
if now.After(rateLimit.ResetTime) {
rateLimit.Count = 1
rateLimit.ResetTime = now.Add(time.Minute)
return rateLimit, true
} else if int64(*channel.RateLimit) > rateLimit.Count {
rateLimit.Count++
return rateLimit, true
}

return rateLimit, false
}

func updateRateLimitStatus(channelId int) {
now := time.Now()
key := ChannelModelKey{ChannelID: channelId}

rateLimitMutex.Lock()
defer rateLimitMutex.Unlock()

val, _ := channelRateLimitStatus.Load(key)
if val == nil {
return
}

rl := val.(*ChannelRateLimit)
if now.After(rl.ResetTime) {
rl.Count = 1
rl.ResetTime = now.Add(time.Minute)
} else {
rl.Count++
}

channelRateLimitStatus.Store(key, rl)
}

func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability

Expand All @@ -118,28 +194,62 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
if err != nil {
return nil, err
}
if len(abilities) <= 0 {
return nil, errors.New("channel not found");
}

channel := Channel{}
if len(abilities) > 0 {
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
weightSum += ability_.Weight + 10
for len(abilities) > 0 {
selectedIndex, err := getRandomWeightedIndex(abilities)
if err != nil {
return nil, err
}
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight) + 10
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break

selectedAbility := abilities[selectedIndex]
channelPtr, err := GetChannelById(selectedAbility.ChannelId, true)
if err != nil {
if err.Error() != "channel not found" {
return nil, err
}
abilities = removeAbility(abilities, selectedIndex)
continue
}
Comment on lines +209 to 216
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Use errors.Is for gorm.ErrRecordNotFound

String compare on error is brittle; treat only “record not found” as skippable.

-		channelPtr, err := GetChannelById(selectedAbility.ChannelId, true)
-		if err != nil {
-			if err.Error() != "channel not found" {
-				return nil, err
-			}
-			abilities = removeAbility(abilities, selectedIndex)
-			continue
-		}
+		channelPtr, err := GetChannelById(selectedAbility.ChannelId, true)
+		if err != nil {
+			if !errors.Is(err, gorm.ErrRecordNotFound) {
+				return nil, err
+			}
+			abilities = removeAbility(abilities, selectedIndex)
+			continue
+		}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
channelPtr, err := GetChannelById(selectedAbility.ChannelId, true)
if err != nil {
if err.Error() != "channel not found" {
return nil, err
}
abilities = removeAbility(abilities, selectedIndex)
continue
}
channelPtr, err := GetChannelById(selectedAbility.ChannelId, true)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
abilities = removeAbility(abilities, selectedIndex)
continue
}
🤖 Prompt for AI Agents
In model/ability.go around lines 209 to 216, the code compares err.Error() to a
string to detect a missing channel which is brittle; replace that string
comparison with errors.Is(err, gorm.ErrRecordNotFound) to correctly detect a
not-found error even if wrapped, return the error for other cases, and when the
record is not found remove the ability and continue; also ensure the file
imports the "errors" package and "gorm.io/gorm" (or references the project’s
gorm error variable) so the errors.Is check compiles.

} else {
return nil, errors.New("channel not found")

channel = *channelPtr
if isRateLimited(channel, channel.Id) {
abilities = removeAbility(abilities, selectedIndex)
continue
}

return channelPtr, nil
}

return nil, errors.New("channel not found")
}

func getRandomWeightedIndex(abilities []Ability) (int, error) {
weightSum := uint(0)
for _, ability := range abilities {
weightSum += ability.Weight
}

if weightSum == 0 {
return common.GetRandomInt(len(abilities)), nil
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err

randomWeight := common.GetRandomInt(int(weightSum))
for i, ability := range abilities {
randomWeight -= int(ability.Weight)
if randomWeight <= 0 {
return i, nil
}
}

return -1, errors.New("unable to select a random weighted index")
}

func removeAbility(abilities []Ability, index int) []Ability {
return append(abilities[:index], abilities[index+1:]...)
}

func (channel *Channel) AddAbilities() error {
Expand Down
8 changes: 8 additions & 0 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type Channel struct {
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"`
RateLimit *int `json:"rate_limit" gorm:"default:0"`
// add after v0.8.5
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
}
Comment on lines +47 to 50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add DB migration and constraints for rate_limit

Schema change requires migration. Backfill nulls to 0 and add a non-negative constraint to prevent invalid values.

Apply one of these migrations (adjust table name if needed):

  • MySQL:
ALTER TABLE channels ADD COLUMN rate_limit INT NOT NULL DEFAULT 0;
ALTER TABLE channels ADD CONSTRAINT chk_rate_limit_nonneg CHECK (rate_limit >= 0);
  • PostgreSQL:
ALTER TABLE channels ADD COLUMN rate_limit INT NOT NULL DEFAULT 0;
ALTER TABLE channels ADD CONSTRAINT chk_rate_limit_nonneg CHECK (rate_limit >= 0);
  • SQLite:
ALTER TABLE channels ADD COLUMN rate_limit INTEGER DEFAULT 0;
UPDATE channels SET rate_limit = 0 WHERE rate_limit IS NULL;
-- (SQLite lacks ADD CONSTRAINT; enforce non-negativity in app code)

Also consider returning 0 instead of null in read APIs to simplify clients.

🤖 Prompt for AI Agents
In model/channel.go around lines 47 to 50, the new RateLimit field introduces a
schema change that needs a DB migration: add the rate_limit column with NOT NULL
DEFAULT 0, backfill existing NULLs to 0, and enforce non-negativity with a CHECK
constraint where supported (MySQL/Postgres); for SQLite add the column, run an
UPDATE to set NULLs to 0 and enforce non-negativity in application code. Create
and run the appropriate migration for your database (adjust table name if
needed), and update read APIs to return 0 instead of null to simplify clients.

Expand Down Expand Up @@ -397,6 +398,13 @@ func (channel *Channel) GetStatusCodeMapping() string {
return *channel.StatusCodeMapping
}

func (channel *Channel) GetRateLimit() int {
if channel.RateLimit == nil || *channel.RateLimit <= 0 {
return 0
}
return *channel.RateLimit
}

func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
Expand Down
89 changes: 69 additions & 20 deletions model/channel_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"one-api/common"
"one-api/setting"
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)

var group2model2channels map[string]map[string][]int // enabled channel
Expand Down Expand Up @@ -130,26 +132,30 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,

channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
channelIds := group2model2channels[group][model]

if len(channels) == 0 {
validChannels := make([]*Channel, 0)
for _, channelId := range channelIds {
if channel, ok := channelsIDM[channelId]; ok {
if !isRedisLimited(*channel, channelId) {
validChannels = append(validChannels, channel)
}
} else {
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
}
}

if len(validChannels) == 0 {
return nil, errors.New("channel not found")
}

if len(channels) == 1 {
if channel, ok := channelsIDM[channels[0]]; ok {
return channel, nil
}
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
if len(validChannels) == 1 {
return validChannels[0], nil
}

uniquePriorities := make(map[int]bool)
for _, channelId := range channels {
if channel, ok := channelsIDM[channelId]; ok {
uniquePriorities[int(channel.GetPriority())] = true
} else {
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
}
for _, channel := range validChannels {
uniquePriorities[int(channel.GetPriority())] = true
}
var sortedUniquePriorities []int
for priority := range uniquePriorities {
Expand All @@ -164,13 +170,9 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,

// get the priority for the given retry number
var targetChannels []*Channel
for _, channelId := range channels {
if channel, ok := channelsIDM[channelId]; ok {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
}
} else {
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
for _, channel := range validChannels {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
}
}

Expand All @@ -195,6 +197,53 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
return nil, errors.New("channel not found")
}

func isRedisLimited(channel Channel, channelId int) bool {
if channel.RateLimit != nil && *channel.RateLimit > 0 {
if !checkRedisLimit(channel, channelId) {
return true
}
}
return false
}

func checkRedisLimit(channel Channel, channelId int) bool {
key := fmt.Sprintf("rate_limit:%d", channelId)

countStr, err := common.RedisGet(key)
if err == redis.Nil {
// Key doesn't exist, set it with expiration
err = common.RedisSet(key, "1", time.Minute)
if err != nil {
common.SysLog(fmt.Sprintf("Error setting rate limit: %v", err))
return false
}
return true
} else if err != nil {
common.SysLog(fmt.Sprintf("Error checking rate limit: %v", err))
return false
}

count, err := strconv.ParseInt(countStr, 10, 64)
if err != nil {
common.SysLog(fmt.Sprintf("Error parsing rate limit count: %v", err))
return false
}

if count > int64(*channel.RateLimit) {
return false
}

// 增加计数
newCount := strconv.FormatInt(count+1, 10)
err = common.RedisSet(key, newCount, time.Minute)
if err != nil {
common.SysLog(fmt.Sprintf("Error incrementing rate limit: %v", err))
return false
}

return true
}
Comment on lines +200 to +245
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Fix Redis rate limiter: use atomic INCR + one-time TTL set; correct boundary check

Current GET/SET is racy and allows limit+1 per window; also resets TTL on every hit (sliding), diverging from in-memory (fixed). Use INCR and set EXPIRE only when creating the key; compare with <= limit.

-func checkRedisLimit(channel Channel, channelId int) bool {
-	key := fmt.Sprintf("rate_limit:%d", channelId)
-
-	countStr, err := common.RedisGet(key)
-	if err == redis.Nil {
-		// Key doesn't exist, set it with expiration
-		err = common.RedisSet(key, "1", time.Minute)
-		if err != nil {
-			common.SysLog(fmt.Sprintf("Error setting rate limit: %v", err))
-			return false
-		}
-		return true
-	} else if err != nil {
-		common.SysLog(fmt.Sprintf("Error checking rate limit: %v", err))
-		return false
-	}
-
-	count, err := strconv.ParseInt(countStr, 10, 64)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Error parsing rate limit count: %v", err))
-		return false
-	}
-
-	if count > int64(*channel.RateLimit) {
-		return false
-	}
-
-	// 增加计数
-	newCount := strconv.FormatInt(count+1, 10)
-	err = common.RedisSet(key, newCount, time.Minute)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Error incrementing rate limit: %v", err))
-		return false
-	}
-
-	return true
-}
+func checkRedisLimit(channel Channel, channelId int) bool {
+	key := fmt.Sprintf("rate_limit:%d", channelId)
+	ctx := context.Background()
+
+	// Atomic increment
+	count, err := common.RDB.Incr(ctx, key).Result()
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Error incrementing rate limit: %v", err))
+		return false
+	}
+	// First hit: set TTL once to create a fixed 1-minute window
+	if count == 1 {
+		if err := common.RDB.Expire(ctx, key, time.Minute).Err(); err != nil {
+			common.SysLog(fmt.Sprintf("Error setting rate limit TTL: %v", err))
+			return false
+		}
+	}
+	// Allow if within limit
+	return count <= int64(*channel.RateLimit)
+}


func CacheGetChannel(id int) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetChannelById(id, true)
Expand Down
3 changes: 3 additions & 0 deletions relay/common/relay_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type ResponsesUsageInfo struct {
type RelayInfo struct {
ChannelType int
ChannelId int
ChannelRateLimit int
TokenId int
TokenKey string
UserId int
Expand Down Expand Up @@ -215,6 +216,7 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
channelRateLimit := common.GetContextKeyInt(c, constant.ContextKeyChannelRateLimit)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)

tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
Expand All @@ -235,6 +237,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
ChannelRateLimit: channelRateLimit,
TokenId: tokenId,
TokenKey: tokenKey,
UserId: userId,
Expand Down
11 changes: 11 additions & 0 deletions web/src/pages/Channel/EditChannel.js
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ const EditChannel = (props) => {
groups: ['default'],
priority: 0,
weight: 0,
rate_limit: 0,
tag: '',
multi_key_mode: 'random',
};
Expand Down Expand Up @@ -1371,6 +1372,16 @@ const EditChannel = (props) => {
style={{ width: '100%' }}
/>
</Col>
<Col span={12}>
<Form.InputNumber
field='rate_limit'
label={t('每分钟最大请求数')}
placeholder={t('每分钟最大请求数')}
min={0}
onNumberChange={(value) => handleInputChange('rate_limit', value)}
style={{ width: '100%' }}
/>
</Col>
</Row>

<Form.Switch
Expand Down