Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

503 lines
18 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package integrations
import (
"bytes"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/kvstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
var bifrostContextKeyProvider = schemas.BifrostContextKey("provider")
var availableIntegrations = []string{
"openai",
"anthropic",
"genai",
"litellm",
"langchain",
"bedrock",
"pydantic",
"cohere",
}
// newBifrostErrorWithCode is like newBifrostError but sets an explicit HTTP status code.
func newBifrostErrorWithCode(err error, message string, statusCode int) *schemas.BifrostError {
e := newBifrostError(err, message)
e.StatusCode = &statusCode
return e
}
// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
// This helper function reduces code duplication when handling non-Bifrost errors.
func newBifrostError(err error, message string) *schemas.BifrostError {
if err == nil {
return &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: message,
},
}
}
return &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: message,
Error: err,
},
}
}
// safeGetRequestType safely obtains the request type from a BifrostStreamChunk chunk.
// It checks multiple sources in order of preference:
// 1. Response ExtraFields if any response is available
// 2. BifrostError ExtraFields if error is available and not nil
// 3. Falls back to "unknown" if no source is available
func safeGetRequestType(chunk *schemas.BifrostStreamChunk) string {
if chunk == nil {
return "unknown"
}
// Try to get RequestType from response ExtraFields (preferred source)
switch {
case chunk.BifrostTextCompletionResponse != nil:
return string(chunk.BifrostTextCompletionResponse.ExtraFields.RequestType)
case chunk.BifrostChatResponse != nil:
return string(chunk.BifrostChatResponse.ExtraFields.RequestType)
case chunk.BifrostResponsesStreamResponse != nil:
return string(chunk.BifrostResponsesStreamResponse.ExtraFields.RequestType)
case chunk.BifrostSpeechStreamResponse != nil:
return string(chunk.BifrostSpeechStreamResponse.ExtraFields.RequestType)
case chunk.BifrostTranscriptionStreamResponse != nil:
return string(chunk.BifrostTranscriptionStreamResponse.ExtraFields.RequestType)
}
// Try to get RequestType from error ExtraFields (fallback)
if chunk.BifrostError != nil && chunk.BifrostError.ExtraFields.RequestType != "" {
return string(chunk.BifrostError.ExtraFields.RequestType)
}
// Final fallback
return "unknown"
}
// extractHeadersFromRequest extracts headers from the request and returns them as a map.
// It uses the fasthttp.RequestCtx.Header.All() method to iterate over all headers.
func extractHeadersFromRequest(ctx *fasthttp.RequestCtx) map[string][]string {
headers := make(map[string][]string)
for key, value := range ctx.Request.Header.All() {
keyStr := string(key)
headers[keyStr] = append(headers[keyStr], string(value))
}
return headers
}
// extractExactPath returns the request path *after* the integration prefix,
// preserving the original query string exactly as sent by the client.
//
// Example:
//
// /openai/v1/chat/completions?model=gpt-4o -> v1/chat/completions?model=gpt-4o
func extractExactPath(ctx *fasthttp.RequestCtx) string {
// ctx.Path() returns only the path (no query) as a []byte backed by fasthttps internal buffers.
// Treat it as read-only; dont append to it directly.
path := ctx.Path() // e.g. "/openai/v1/chat/completions"
// Strip the integration prefix only if its at the start.
for _, integration := range availableIntegrations {
if bytes.HasPrefix(path, []byte("/"+integration+"/")) {
path = path[len("/"+integration+"/"):]
break
}
}
// Raw query string as sent by client (unparsed, preserves ordering/duplicates/encoding).
q := ctx.URI().QueryString() // e.g. "model=gpt-4o&stream=true"
if len(q) == 0 {
// No query → just return the (possibly trimmed) path.
return string(path)
}
// --- Build "<path>?<query>" efficiently and safely ---
//
// Why not do: return string(path) + "?" + string(q) ?
// - That allocates multiple temporary strings and may copy data more than necessary.
//
// Why not append into 'path' directly?
// - 'path' may alias fasthttps internal buffers; mutating/expanding it could corrupt request state.
//
// We instead allocate a new buffer with exact capacity and copy into it,
// staying in []byte until the final string conversion (1 allocation for the new slice).
out := make([]byte, 0, len(path)+1+len(q)) // pre-size: path + "?" + query
out = append(out, path...) // copy path bytes
out = append(out, '?') // separator
out = append(out, q...) // copy raw query bytes
return string(out)
}
// sendStreamError sends an error response for a streaming request that failed before streaming started.
// It propagates the provider's HTTP status code and returns a JSON error body (not SSE format),
// since no streaming has begun and clients should receive a standard error response.
func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, config RouteConfig, bifrostErr *schemas.BifrostError) {
// Forward provider response headers from context so streaming error responses include them
if bifrostCtx != nil {
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
for key, value := range headers {
ctx.Response.Header.Set(key, value)
}
}
}
// Set the HTTP status code from the provider error
if bifrostErr.StatusCode != nil {
ctx.SetStatusCode(*bifrostErr.StatusCode)
} else {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
}
ctx.SetContentType("application/json")
// Always use the route-level ErrorConverter (not StreamConfig.ErrorConverter) because
// sendStreamError returns JSON, not SSE. StreamConfig.ErrorConverter is designed for
// in-stream SSE errors (e.g., Anthropic's returns a raw SSE string that would be
// double-escaped by JSON marshaling).
errorResponse := config.ErrorConverter(bifrostCtx, bifrostErr)
errorJSON, err := sonic.Marshal(errorResponse)
if err != nil {
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetContentType("text/plain; charset=utf-8")
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
return
}
ctx.SetBody(errorJSON)
}
// sendError sends an error response with the appropriate status code and JSON body.
// It handles different error types (string, error interface, or arbitrary objects).
func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) {
// Forward provider response headers from context so error responses include them
if bifrostCtx != nil {
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
for key, value := range headers {
ctx.Response.Header.Set(key, value)
}
}
}
if bifrostErr.StatusCode != nil {
ctx.SetStatusCode(*bifrostErr.StatusCode)
} else {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
}
ctx.SetContentType("application/json")
// Marshal the error for response and log the error for diagnostics
responseObj := errorConverter(bifrostCtx, bifrostErr)
errorBody, err := sonic.Marshal(responseObj)
if err != nil {
// Log the marshal failure and return a plain text error
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetContentType("text/plain; charset=utf-8")
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
return
}
ctx.SetBody(errorBody)
}
// sendSuccess sends a successful response with HTTP 200 status and JSON body.
func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, response interface{}, extraHeaders map[string]string) {
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetContentType("application/json")
if extraHeaders != nil {
for key, value := range extraHeaders {
ctx.Response.Header.Set(key, value)
}
}
responseBody, err := sonic.Marshal(response)
if err != nil {
g.sendError(ctx, bifrostCtx, errorConverter, newBifrostError(err, "failed to encode response"))
return
}
ctx.SetBody(responseBody)
}
// tryStreamLargeResponse checks if large response mode was activated by the provider,
// sets the transport marker, and streams the response directly to the client.
// Returns true if the response was handled (caller should return).
func (g *GenericRouter) tryStreamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
isLargeResponse, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
if !ok || !isLargeResponse {
return false
}
// Forward provider response headers before streaming — providers store them in
// context via BifrostContextKeyProviderResponseHeaders, but some early-return
// branches in the router skip the common footer that normally forwards them.
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
for key, value := range headers {
ctx.Response.Header.Set(key, value)
}
}
if g.streamLargeResponse(ctx, bifrostCtx) {
ctx.SetUserValue(lib.FastHTTPUserValueLargeResponseMode, true)
}
return true
}
// streamLargeResponse streams the large response body directly from the upstream provider to the client.
// This bypasses the normal serialize → set body path, piping the response bytes unchanged.
func (g *GenericRouter) streamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
// Enterprise hook: wrap the reader with Phase B scanning (e.g., usage extraction
// from the full response stream) before streaming to client.
if g.largeResponseHook != nil {
g.largeResponseHook(ctx, bifrostCtx)
}
if !lib.StreamLargeResponseBody(ctx, bifrostCtx) {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetBodyString("large response reader not available")
return false
}
return true
}
// extractAndParseFallbacks extracts fallbacks from the integration request and adds them to the BifrostRequest
func (g *GenericRouter) extractAndParseFallbacks(req interface{}, bifrostReq *schemas.BifrostRequest) error {
// Check if the request has a fallbacks field ([]string)
fallbacks, err := g.extractFallbacksFromRequest(req)
if err != nil {
return fmt.Errorf("failed to extract fallbacks: %w", err)
}
if len(fallbacks) == 0 {
return nil // No fallbacks to process
}
provider, _, _ := bifrostReq.GetRequestFields()
// Parse fallbacks from strings to Fallback structs
parsedFallbacks := make([]schemas.Fallback, 0, len(fallbacks))
for _, fallbackStr := range fallbacks {
if fallbackStr == "" {
continue // Skip empty strings
}
// Use ParseModelString to extract provider and model
provider, model := schemas.ParseModelString(fallbackStr, provider)
parsedFallback := schemas.Fallback{
Provider: provider,
Model: model,
}
parsedFallbacks = append(parsedFallbacks, parsedFallback)
}
if len(parsedFallbacks) == 0 {
return nil // No valid fallbacks found
}
// Add fallbacks to the main BifrostRequest
bifrostReq.SetFallbacks(parsedFallbacks)
// Also add fallbacks to the specific request type if it exists
switch bifrostReq.RequestType {
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
if bifrostReq.TextCompletionRequest != nil {
bifrostReq.TextCompletionRequest.Fallbacks = parsedFallbacks
}
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
if bifrostReq.ChatRequest != nil {
bifrostReq.ChatRequest.Fallbacks = parsedFallbacks
}
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest:
if bifrostReq.ResponsesRequest != nil {
bifrostReq.ResponsesRequest.Fallbacks = parsedFallbacks
}
case schemas.EmbeddingRequest:
if bifrostReq.EmbeddingRequest != nil {
bifrostReq.EmbeddingRequest.Fallbacks = parsedFallbacks
}
case schemas.RerankRequest:
if bifrostReq.RerankRequest != nil {
bifrostReq.RerankRequest.Fallbacks = parsedFallbacks
}
case schemas.SpeechRequest, schemas.SpeechStreamRequest:
if bifrostReq.SpeechRequest != nil {
bifrostReq.SpeechRequest.Fallbacks = parsedFallbacks
}
case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
if bifrostReq.TranscriptionRequest != nil {
bifrostReq.TranscriptionRequest.Fallbacks = parsedFallbacks
}
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest:
if bifrostReq.ImageGenerationRequest != nil {
bifrostReq.ImageGenerationRequest.Fallbacks = parsedFallbacks
}
}
return nil
}
// extractFallbacksFromRequest uses reflection to extract fallbacks field from any request type
func (g *GenericRouter) extractFallbacksFromRequest(req interface{}) ([]string, error) {
if req == nil {
return nil, nil
}
// Try to use reflection to find a "fallbacks" field
reqValue := reflect.ValueOf(req)
if reqValue.Kind() == reflect.Ptr {
reqValue = reqValue.Elem()
}
if reqValue.Kind() != reflect.Struct {
return nil, nil // Not a struct, no fallbacks
}
// Look for the "fallbacks" field
fallbacksField := reqValue.FieldByName("fallbacks")
if !fallbacksField.IsValid() {
return nil, nil // No fallbacks field found
}
// Handle different types of fallbacks field
switch fallbacksField.Kind() {
case reflect.Slice:
if fallbacksField.Type().Elem().Kind() == reflect.String {
// []string case
fallbacks := make([]string, fallbacksField.Len())
for i := 0; i < fallbacksField.Len(); i++ {
fallbacks[i] = fallbacksField.Index(i).String()
}
return fallbacks, nil
}
case reflect.String:
// Single string case - treat as one fallback
return []string{fallbacksField.String()}, nil
}
return nil, nil
}
// getVirtualKeyFromBifrostContext extracts the virtual key value from bifrost context.
// Returns nil if no VK is present (e.g., direct key mode or no governance).
func getVirtualKeyFromBifrostContext(ctx *schemas.BifrostContext) *string {
vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
if vkValue == "" {
return nil
}
return &vkValue
}
// getResultTTLFromHeaderWithDefault extracts the result TTL from the x-bf-async-job-result-ttl header.
// Returns the default TTL if the header is not present or invalid.
func getResultTTLFromHeaderWithDefault(ctx *fasthttp.RequestCtx, defaultTTL int) int {
resultTTL := string(ctx.Request.Header.Peek(schemas.AsyncHeaderResultTTL))
if resultTTL == "" {
return defaultTTL
}
resultTTLInt, err := strconv.Atoi(resultTTL)
if err != nil || resultTTLInt < 0 {
return defaultTTL
}
return resultTTLInt
}
// isAnthropicAPIKeyAuth checks if the request uses standard API key authentication.
// Returns true for API key auth (x-api-key header), false for OAuth (Bearer sk-ant-oat*).
// This is required for Claude Code specifically, which may use OAuth authentication.
// Default behavior is to assume API mode when neither x-api-key nor OAuth token is present.
func isAnthropicAPIKeyAuth(ctx *fasthttp.RequestCtx) bool {
// If x-api-key header is present - this is definitely API mode
if apiKey := string(ctx.Request.Header.Peek("x-api-key")); apiKey != "" {
return true
}
// Check for OAuth token in Authorization header
if authHeader := string(ctx.Request.Header.Peek("Authorization")); authHeader != "" {
if strings.HasPrefix(strings.ToLower(authHeader), "bearer sk-ant-oat") {
return false // OAuth mode, NOT API
}
}
// Default to API mode
return true
}
// resolveLargePayloadMetadata returns metadata from the sync context key,
// falling back to a non-blocking read from the deferred channel.
// If deferred metadata is resolved, it is cached in the sync key for later readers.
func resolveLargePayloadMetadata(bifrostCtx *schemas.BifrostContext) *schemas.LargePayloadMetadata {
if bifrostCtx == nil {
return nil
}
if metadata, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata); ok && metadata != nil {
return metadata
}
ch, ok := bifrostCtx.Value(schemas.BifrostContextKeyDeferredLargePayloadMetadata).(<-chan *schemas.LargePayloadMetadata)
if !ok || ch == nil {
return nil
}
select {
case metadata := <-ch:
if metadata != nil {
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, metadata)
}
return metadata
default:
return nil
}
}
// ParseProviderScopedVideoID parses a provider-scoped video ID in the form "id:provider".
// The ID portion is automatically URL-decoded to restore the original ID.
func ParseProviderScopedVideoID(videoID string) (schemas.ModelProvider, string, error) {
parts := strings.SplitN(videoID, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", fmt.Errorf("video_id must be in id:provider format")
}
provider := schemas.ModelProvider(parts[1])
rawID := parts[0]
// URL decode the ID to restore original characters (e.g., %2F -> /)
// This handles IDs from all providers that may contain special characters
if decoded, err := url.PathUnescape(rawID); err == nil {
rawID = decoded
}
return provider, rawID, nil
}
func getProviderFromHeader(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) schemas.ModelProvider {
providerHeader := string(ctx.Request.Header.Peek("x-model-provider"))
if providerHeader == "" {
return defaultProvider
}
return schemas.ModelProvider(providerHeader)
}
func RegisterKVDecoders(store *kvstore.Store) {
store.RegisterDecoder("genai_upload_session:", func(data []byte) (any, error) {
var v gemini.GeminiResumableUploadSession
return &v, sonic.Unmarshal(data, &v)
})
}