Skip to content

Commit

Permalink
feat: implement apiToken failover mechanism (#1256)
Browse files Browse the repository at this point in the history
  • Loading branch information
cr7258 authored Nov 16, 2024
1 parent f2a5df3 commit d24123a
Show file tree
Hide file tree
Showing 33 changed files with 1,552 additions and 1,225 deletions.
29 changes: 20 additions & 9 deletions plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ description: AI 代理插件配置参考

`provider`的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------|
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |

`context`的配置字段说明如下:

Expand Down Expand Up @@ -75,6 +76,16 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
如果启用了raw模式,custom-setting会直接用输入的`name``value`去更改请求中的json内容,而不对参数名称做任何限制和修改。
对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。

`failover` 的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |

### 提供商特有配置

Expand Down
10 changes: 7 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/config/config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package config

import (
"github.com/tidwall/gjson"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

// @Name ai-proxy
Expand Down Expand Up @@ -75,13 +75,17 @@ func (c *PluginConfig) Validate() error {
return nil
}

func (c *PluginConfig) Complete() error {
func (c *PluginConfig) Complete(log wrapper.Log) error {
if c.activeProviderConfig == nil {
c.activeProvider = nil
return nil
}
var err error
c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig)

providerConfig := c.GetProviderConfig()
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)

return err
}

Expand Down
49 changes: 35 additions & 14 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}

return nil
}

Expand All @@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}

return nil
}

Expand All @@ -80,7 +82,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if apiName == "" && !providerConfig.IsOriginal() {
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
apiName = handler.GetApiName(path.Path)
}
}

if apiName == "" {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
log.Debugf("[onHttpRequestHeader] no send response")
Expand All @@ -89,8 +97,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)

if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
Expand All @@ -102,6 +113,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}

_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
Expand Down Expand Up @@ -156,15 +168,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo

log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())

providerConfig := pluginConfig.GetProviderConfig()
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)

status, err := proxywasm.GetHttpResponseHeader(":status")
if err != nil || status != "200" {
if err != nil {
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)

return types.ActionContinue
}

// Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)

if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
Expand Down Expand Up @@ -233,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
return types.ActionContinue
}

func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
return ""
}

func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
Expand All @@ -252,3 +263,13 @@ func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
(*ctx).BufferResponseBody()
}
}

func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
return ""
}
58 changes: 9 additions & 49 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package provider

import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

// ai360Provider is the provider for 360 OpenAI service.
Expand Down Expand Up @@ -46,10 +44,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(ai360Domain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
Expand All @@ -58,47 +53,12 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
}

func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
51 changes: 13 additions & 38 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ package provider
import (
"errors"
"fmt"
"net/http"
"net/url"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

// azureProvider is the provider for Azure OpenAI service.

type azureProviderInitializer struct {
}

Expand Down Expand Up @@ -55,47 +54,23 @@ func (m *azureProvider) GetProviderType() string {
}

func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestPath(m.serviceUrl.RequestURI())
_ = util.OverwriteRequestHost(m.serviceUrl.Host)
_ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0])
if apiName == ApiNameChatCompletion {
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
} else {
ctx.DontReadRequestBody()
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}

func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.azure.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.azure.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}

func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
Loading

0 comments on commit d24123a

Please sign in to comment.