Skip to content

Commit

Permalink
feat: relay retry ignore forbidden channel (#5435)
Browse files Browse the repository at this point in the history
* feat: relay retry ignore forbidden channel

* fix: don't reuse meta
  • Loading branch information
zijiren233 authored Mar 3, 2025
1 parent 8f45b6f commit feb5f9d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 46 deletions.
88 changes: 61 additions & 27 deletions service/aiproxy/controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
return err, shouldRetry(c, err.StatusCode)
}

func getChannelWithFallback(cache *dbmodel.ModelCaches, model string, failedChannelIDs ...int) (*dbmodel.Channel, error) {
channel, err := cache.GetRandomSatisfiedChannel(model, failedChannelIDs...)
func getChannelWithFallback(cache *dbmodel.ModelCaches, model string, ignoreChannelIDs ...int) (*dbmodel.Channel, error) {
channel, err := cache.GetRandomSatisfiedChannel(model, ignoreChannelIDs...)
if err == nil {
return channel, nil
}
Expand Down Expand Up @@ -110,17 +110,14 @@ func relay(c *gin.Context, mode int, relayController RelayController) {
if err != nil {
log.Errorf("get %s auto banned channels failed: %+v", requestModel, err)
}

log.Debugf("%s model banned channels: %+v", requestModel, ids)

failedChannelIDs := []int{}
ignoreChannelIDs := make([]int, 0, len(ids))
for _, id := range ids {
failedChannelIDs = append(failedChannelIDs, int(id))
ignoreChannelIDs = append(ignoreChannelIDs, int(id))
}

mc := middleware.GetModelCaches(c)

channel, err := getChannelWithFallback(mc, requestModel, failedChannelIDs...)
channel, err := getChannelWithFallback(mc, requestModel, ignoreChannelIDs...)
if err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": &model.Error{
Expand All @@ -137,46 +134,68 @@ func relay(c *gin.Context, mode int, relayController RelayController) {
if bizErr == nil {
return
}
failedChannelIDs = append(failedChannelIDs, channel.ID)
requestID := middleware.GetRequestID(c)
var retryTimes int64
if retry {
retryTimes = config.GetRetryTimes()
if !retry {
bizErr.Error.Message = middleware.MessageWithRequestID(c, bizErr.Error.Message)
c.JSON(bizErr.StatusCode, bizErr)
return
}

var lastCanContinueChannel *dbmodel.Channel

retryTimes := config.GetRetryTimes()
if !channelCanContinue(bizErr.StatusCode) {
ignoreChannelIDs = append(ignoreChannelIDs, channel.ID)
} else {
lastCanContinueChannel = channel
}

for i := retryTimes; i > 0; i-- {
newChannel, err := mc.GetRandomSatisfiedChannel(requestModel, failedChannelIDs...)
newChannel, err := mc.GetRandomSatisfiedChannel(requestModel, ignoreChannelIDs...)
if err != nil {
if errors.Is(err, dbmodel.ErrChannelsNotFound) {
break
}
// use first channel to retry
if !errors.Is(err, dbmodel.ErrChannelsExhausted) {
if !errors.Is(err, dbmodel.ErrChannelsExhausted) ||
lastCanContinueChannel == nil {
break
}
newChannel = channel
// use last can continue channel to retry
newChannel = lastCanContinueChannel
}
log.Warnf("using channel %s(%d) to retry (remain times %d)", newChannel.Name, newChannel.ID, i)
log.Warnf("using channel %s (type: %d, id: %d) to retry (remain times %d)",
newChannel.Name,
newChannel.Type,
newChannel.ID,
i-1,
)

requestBody, err := common.GetRequestBody(c.Request)
if err != nil {
log.Errorf("GetRequestBody failed: %+v", err)
break
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
meta.Reset(newChannel)
//nolint:gosec
// random wait 1-2 seconds
time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)

if shouldDelay(bizErr.StatusCode) {
//nolint:gosec
// random wait 1-2 seconds
time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
}

meta := middleware.NewMetaByContext(c, newChannel, requestModel, mode)
bizErr, retry = RelayHelper(meta, c, relayController)
if bizErr == nil {
return
}
if !retry {
break
}
failedChannelIDs = append(failedChannelIDs, newChannel.ID)
if !channelCanContinue(bizErr.StatusCode) {
ignoreChannelIDs = append(ignoreChannelIDs, newChannel.ID)
} else {
lastCanContinueChannel = newChannel
}
}

if bizErr != nil {
bizErr.Error.Message = middleware.MessageWithRequestID(bizErr.Error.Message, requestID)
bizErr.Error.Message = middleware.MessageWithRequestID(c, bizErr.Error.Message)
c.JSON(bizErr.StatusCode, bizErr)
}
}
Expand All @@ -195,6 +214,21 @@ func shouldRetry(_ *gin.Context, statusCode int) bool {
return ok
}

// channelCanContinueStatusCodesMap is the set of HTTP status codes after which
// the channel that produced the error may still be picked for a later retry
// rather than being appended to the ignore list.
var channelCanContinueStatusCodesMap = map[int]struct{}{
	http.StatusRequestTimeout:  {},
	http.StatusTooManyRequests: {},
	http.StatusGatewayTimeout:  {},
}

// channelCanContinue reports whether statusCode is a member of
// channelCanContinueStatusCodesMap.
func channelCanContinue(statusCode int) bool {
	if _, found := channelCanContinueStatusCodesMap[statusCode]; found {
		return true
	}
	return false
}

// shouldDelay reports whether a retry that follows a response with the given
// HTTP status code should be preceded by a pause. Only 429 (Too Many Requests)
// qualifies.
func shouldDelay(statusCode int) bool {
	switch statusCode {
	case http.StatusTooManyRequests:
		return true
	default:
		return false
	}
}

// Record an error only when it is a channel error; errors caused by invalid user request parameters do not need to be recorded.
func shouldErrorMonitor(statusCode int) bool {
return statusCode != http.StatusBadRequest
Expand Down
6 changes: 3 additions & 3 deletions service/aiproxy/middleware/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ const (
ErrorTypeAIPROXY = "aiproxy_error"
)

func MessageWithRequestID(message string, id string) string {
return fmt.Sprintf("%s (aiproxy: %s)", message, id)
func MessageWithRequestID(c *gin.Context, message string) string {
return fmt.Sprintf("%s (aiproxy: %s)", message, GetRequestID(c))
}

func abortLogWithMessage(c *gin.Context, statusCode int, message string) {
Expand All @@ -23,7 +23,7 @@ func abortLogWithMessage(c *gin.Context, statusCode int, message string) {
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": &model.Error{
Message: MessageWithRequestID(message, GetRequestID(c)),
Message: MessageWithRequestID(c, message),
Type: ErrorTypeAIPROXY,
},
})
Expand Down
27 changes: 11 additions & 16 deletions service/aiproxy/relay/meta/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,22 @@ func NewMeta(
}

if channel != nil {
meta.Reset(channel)
meta.Channel = &ChannelMeta{
Name: channel.Name,
BaseURL: channel.BaseURL,
Key: channel.Key,
ID: channel.ID,
Type: channel.Type,
}
if channel.Config != nil {
meta.ChannelConfig = *channel.Config
}
meta.ActualModel, _ = GetMappedModelName(modelName, channel.ModelMapping)
}

return &meta
}

// Reset points the meta at a different channel: it replaces m.Channel with a
// fresh ChannelMeta copied from channel, copies the channel's config when one
// is set, recomputes ActualModel from OriginModel through the channel's model
// mapping, and finally clears the stored values.
//
// NOTE(review): when channel.Config is nil, m.ChannelConfig is left as it was,
// so a config copied from a previously bound channel would remain — confirm
// whether that is intended.
func (m *Meta) Reset(channel *model.Channel) {
	channelMeta := ChannelMeta{
		Name:    channel.Name,
		BaseURL: channel.BaseURL,
		Key:     channel.Key,
		ID:      channel.ID,
		Type:    channel.Type,
	}
	m.Channel = &channelMeta

	if cfg := channel.Config; cfg != nil {
		m.ChannelConfig = *cfg
	}

	// The second result of GetMappedModelName is deliberately discarded.
	mapped, _ := GetMappedModelName(m.OriginModel, channel.ModelMapping)
	m.ActualModel = mapped

	m.ClearValues()
}

// ClearValues empties m.values with the built-in clear and touches no other
// field of the receiver.
// NOTE(review): the declaration of values is not visible here — presumably a
// map; confirm.
func (m *Meta) ClearValues() {
clear(m.values)
}
Expand Down

0 comments on commit feb5f9d

Please sign in to comment.