first commit
This commit is contained in:
162
core/providers/mistral/custom_provider_test.go
Normal file
162
core/providers/mistral/custom_provider_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// customMistralProviderName is the alias under which these tests register a
// custom provider whose base provider type is Mistral.
const customMistralProviderName schemas.ModelProvider = "custom-mistral"
|
||||
|
||||
func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
resp.SetStatusCode(http.StatusBadRequest)
|
||||
resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`)
|
||||
|
||||
bifrostErr := ParseMistralError(resp)
|
||||
require.NotNil(t, bifrostErr)
|
||||
require.NotNil(t, bifrostErr.Error)
|
||||
|
||||
assert.Equal(t, "invalid request", bifrostErr.Error.Message)
|
||||
assert.Equal(t, schemas.Ptr("invalid_request_error"), bifrostErr.Error.Type)
|
||||
assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code)
|
||||
// Note: ExtraFields.Provider is populated by bifrost.go's dispatcher via
|
||||
// PopulateExtraFields, not by ParseMistralError called in isolation.
|
||||
}
|
||||
|
||||
func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var capturedRequest map[string]any
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, sonic.Unmarshal(body, &capturedRequest))
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
|
||||
_, err = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"mistral-small-latest\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"hello\"}}]}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{BaseURL: server.URL},
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
CustomProviderKey: string(customMistralProviderName),
|
||||
BaseProviderType: schemas.Mistral,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
ctx.SetValue(schemas.BifrostContextKeyIsCustomProvider, true)
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: customMistralProviderName,
|
||||
Model: "mistral-small-latest",
|
||||
Input: []schemas.ChatMessage{{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("hello"),
|
||||
},
|
||||
}},
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: schemas.Ptr(32),
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{
|
||||
Name: "lookup",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
postHookRunner := func(_ *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
return response, err
|
||||
}
|
||||
|
||||
stream, bifrostErr := provider.ChatCompletionStream(ctx, postHookRunner, nil, schemas.Key{}, request)
|
||||
require.Nil(t, bifrostErr)
|
||||
|
||||
var firstResponse *schemas.BifrostChatResponse
|
||||
for chunk := range stream {
|
||||
if chunk.BifrostError != nil {
|
||||
t.Fatalf("unexpected stream error: %s", chunk.BifrostError.Error.Message)
|
||||
}
|
||||
if chunk.BifrostChatResponse != nil {
|
||||
firstResponse = chunk.BifrostChatResponse
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, firstResponse)
|
||||
// Note: ExtraFields.Provider on stream chunks is populated by bifrost.go's
|
||||
// dispatcher via PopulateExtraFields, not by provider streaming methods
|
||||
// called in isolation.
|
||||
|
||||
require.NotNil(t, capturedRequest)
|
||||
assert.Equal(t, float64(32), capturedRequest["max_tokens"])
|
||||
assert.NotContains(t, capturedRequest, "max_completion_tokens")
|
||||
assert.Equal(t, "any", capturedRequest["tool_choice"])
|
||||
assert.Equal(t, "mistral-small-latest", capturedRequest["model"])
|
||||
assert.Equal(t, true, capturedRequest["stream"])
|
||||
assert.Equal(t, customMistralProviderName, provider.GetProviderKey())
|
||||
assert.Equal(t, customMistralProviderName, request.Provider)
|
||||
}
|
||||
|
||||
func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := w.Write([]byte(`{"object":"list","data":[{"object":"embedding","embedding":[0.1,0.2],"index":0}],"model":"codestral-embed","usage":{"prompt_tokens":1,"total_tokens":1}}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{BaseURL: server.URL},
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
CustomProviderKey: string(customMistralProviderName),
|
||||
BaseProviderType: schemas.Mistral,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
request := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: customMistralProviderName,
|
||||
Model: "codestral-embed",
|
||||
Input: &schemas.EmbeddingInput{
|
||||
Texts: []string{"hello"},
|
||||
},
|
||||
}
|
||||
|
||||
response, bifrostErr := provider.Embedding(ctx, schemas.Key{}, request)
|
||||
require.Nil(t, bifrostErr)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Note: ExtraFields.Provider and ResolvedModelUsed are populated by
|
||||
// bifrost.go's dispatcher via PopulateExtraFields, not by provider
|
||||
// methods called in isolation.
|
||||
}
|
||||
71
core/providers/mistral/errors.go
Normal file
71
core/providers/mistral/errors.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// MistralErrorResponse captures both Mistral's top-level error shape and nested OpenAI-style errors.
|
||||
type MistralErrorResponse struct {
|
||||
Object string `json:"object,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Error *schemas.ErrorField `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ParseMistralError parses Mistral-specific error responses.
|
||||
func ParseMistralError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var errorResp MistralErrorResponse
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
if bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
|
||||
if errorResp.Error != nil {
|
||||
if strings.TrimSpace(errorResp.Error.Message) != "" {
|
||||
bifrostErr.Error.Message = errorResp.Error.Message
|
||||
}
|
||||
if errorResp.Error.Type != nil && strings.TrimSpace(*errorResp.Error.Type) != "" {
|
||||
bifrostErr.Error.Type = errorResp.Error.Type
|
||||
bifrostErr.Type = errorResp.Error.Type
|
||||
}
|
||||
if errorResp.Error.Code != nil && strings.TrimSpace(*errorResp.Error.Code) != "" {
|
||||
bifrostErr.Error.Code = errorResp.Error.Code
|
||||
}
|
||||
bifrostErr.Error.Param = errorResp.Error.Param
|
||||
if errorResp.Error.EventID != nil {
|
||||
bifrostErr.Error.EventID = errorResp.Error.EventID
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(errorResp.Message) != "" {
|
||||
bifrostErr.Error.Message = errorResp.Message
|
||||
}
|
||||
if strings.TrimSpace(errorResp.Type) != "" {
|
||||
errorType := schemas.Ptr(errorResp.Type)
|
||||
bifrostErr.Error.Type = errorType
|
||||
bifrostErr.Type = errorType
|
||||
}
|
||||
if strings.TrimSpace(errorResp.Code) != "" {
|
||||
bifrostErr.Error.Code = schemas.Ptr(errorResp.Code)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(bifrostErr.Error.Message) == "" {
|
||||
if bifrostErr.StatusCode != nil {
|
||||
bifrostErr.Error.Message = fmt.Sprintf("provider API error (status %d)", *bifrostErr.StatusCode)
|
||||
} else {
|
||||
bifrostErr.Error.Message = "provider API error"
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostErr
|
||||
}
|
||||
893
core/providers/mistral/mistral.go
Normal file
893
core/providers/mistral/mistral.go
Normal file
@@ -0,0 +1,893 @@
|
||||
// Package mistral implements the Mistral provider.
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// MistralProvider implements the Provider interface for Mistral's API.
|
||||
type MistralProvider struct {
|
||||
logger schemas.Logger // Logger for provider operations
|
||||
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
||||
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
||||
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
||||
customProviderConfig *schemas.CustomProviderConfig
|
||||
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
||||
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
||||
}
|
||||
|
||||
// NewMistralProvider creates a new Mistral provider instance.
|
||||
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
||||
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
||||
func NewMistralProvider(config *schemas.ProviderConfig, logger schemas.Logger) *MistralProvider {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
// Pre-warm response pools
|
||||
// for range config.ConcurrencyAndBufferSize.Concurrency {
|
||||
// mistralResponsePool.Put(&schemas.BifrostResponse{})
|
||||
// }
|
||||
|
||||
// Configure proxy and retry policy
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
streamingClient := providerUtils.BuildStreamingClient(client)
|
||||
// Set default BaseURL if not provided
|
||||
if config.NetworkConfig.BaseURL == "" {
|
||||
config.NetworkConfig.BaseURL = "https://api.mistral.ai"
|
||||
}
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
return &MistralProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
streamingClient: streamingClient,
|
||||
networkConfig: config.NetworkConfig,
|
||||
customProviderConfig: config.CustomProviderConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for Mistral.
|
||||
func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return providerUtils.GetProviderName(schemas.Mistral, provider.customProviderConfig)
|
||||
}
|
||||
|
||||
// listModelsByKey performs a list models request for a single key.
|
||||
// Returns the response and latency, or an error if the request fails.
|
||||
func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
// Create request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/models"))
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
req.Header.SetContentType("application/json")
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
|
||||
// Make request
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
bifrostErr := ParseMistralError(resp)
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Copy response body before releasing
|
||||
responseBody := append([]byte(nil), resp.Body()...)
|
||||
|
||||
// Parse Mistral's response
|
||||
var mistralResponse MistralListModelsResponse
|
||||
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &mistralResponse, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Create final response
|
||||
response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
|
||||
|
||||
response.ExtraFields.Latency = latency.Milliseconds()
|
||||
|
||||
// Set raw request if enabled
|
||||
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
||||
response.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
// Set raw response if enabled
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
response.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ListModels performs a list models request to Mistral's API.
|
||||
// Requests are made concurrently for improved performance.
|
||||
func (provider *MistralProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
return providerUtils.HandleMultipleListModelsRequests(
|
||||
ctx,
|
||||
keys,
|
||||
request,
|
||||
provider.listModelsByKey,
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletion is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TextCompletionStream performs a streaming text completion request to Mistral's API.
|
||||
// It formats the request, sends it to Mistral, and processes the response.
|
||||
// Returns a channel of BifrostStreamChunk objects or an error if the request fails.
|
||||
func (provider *MistralProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// normalizeChatRequestForConversion returns the request unchanged for the stock Mistral
|
||||
// provider. For custom aliases (e.g. a provider registered as "custom-mistral" with
|
||||
// BaseProviderType=Mistral), it returns a shallow copy with Provider set to schemas.Mistral
|
||||
// so the shared OpenAI converter applies Mistral-specific compatibility (max_completion_tokens
|
||||
// → max_tokens, tool_choice struct → "any"). The caller's request is never mutated.
|
||||
func (provider *MistralProvider) normalizeChatRequestForConversion(request *schemas.BifrostChatRequest) *schemas.BifrostChatRequest {
|
||||
if request == nil || provider.customProviderConfig == nil || request.Provider == schemas.Mistral {
|
||||
return request
|
||||
}
|
||||
normalized := *request
|
||||
normalized.Provider = schemas.Mistral
|
||||
return &normalized
|
||||
}
|
||||
|
||||
// ChatCompletion performs a chat completion request to the Mistral API.
|
||||
func (provider *MistralProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIChatCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
provider.normalizeChatRequestForConversion(request),
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
ParseMistralError,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletionStream performs a streaming chat completion request to the Mistral API.
|
||||
// It supports real-time streaming of responses using Server-Sent Events (SSE).
|
||||
// Uses Mistral's OpenAI-compatible streaming format.
|
||||
// Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails.
|
||||
func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if key.Value.GetValue() != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
||||
}
|
||||
// Use shared OpenAI-compatible streaming logic
|
||||
return openai.HandleOpenAIChatCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+"/v1/chat/completions",
|
||||
provider.normalizeChatRequestForConversion(request),
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
ParseMistralError,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// Responses performs a responses request to the Mistral API.
|
||||
func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := chatResponse.ToBifrostResponsesResponse()
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ResponsesStream performs a streaming responses request to the Mistral API.
|
||||
func (provider *MistralProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
|
||||
return provider.ChatCompletionStream(
|
||||
ctx,
|
||||
postHookRunner,
|
||||
postHookSpanFinalizer,
|
||||
key,
|
||||
request.ToChatRequest(),
|
||||
)
|
||||
}
|
||||
|
||||
// Embedding generates embeddings for the given input text(s) using the Mistral API.
|
||||
// Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s).
|
||||
func (provider *MistralProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
// Use the shared embedding request handler
|
||||
return openai.HandleOpenAIEmbeddingRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// Speech is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Rerank is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// OCR performs an OCR request to the Mistral API.
|
||||
// It sends a JSON request to Mistral's OCR endpoint and returns the extracted content.
|
||||
func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
// Convert Bifrost request to Mistral format
|
||||
mistralReq := ToMistralOCRRequest(request)
|
||||
if mistralReq == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil)
|
||||
}
|
||||
|
||||
// Marshal request body
|
||||
requestBody, err := sonic.Marshal(mistralReq)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
|
||||
// Merge extra params into JSON payload
|
||||
if len(mistralReq.ExtraParams) > 0 {
|
||||
requestBody, err = providerUtils.MergeExtraParamsIntoJSON(requestBody, mistralReq.ExtraParams)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/ocr"))
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
|
||||
req.SetBody(requestBody)
|
||||
|
||||
// Set raw request if enabled
|
||||
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
||||
request.RawRequestBody = requestBody
|
||||
}
|
||||
|
||||
// Make request
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, ParseMistralError(resp)
|
||||
}
|
||||
|
||||
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
// Check for empty response
|
||||
trimmed := strings.TrimSpace(string(responseBody))
|
||||
if len(trimmed) == 0 {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: schemas.ErrProviderResponseEmpty,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
copiedResponseBody := append([]byte(nil), responseBody...)
|
||||
|
||||
// Parse Mistral's OCR response
|
||||
var mistralResponse MistralOCRResponse
|
||||
if err := sonic.Unmarshal(copiedResponseBody, &mistralResponse); err != nil {
|
||||
if providerUtils.IsHTMLResponse(resp, copiedResponseBody) {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: schemas.ErrProviderResponseHTML,
|
||||
Error: errors.New(string(copiedResponseBody)),
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
||||
}
|
||||
|
||||
// Convert to Bifrost format
|
||||
response := mistralResponse.ToBifrostOCRResponse()
|
||||
if response == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil)
|
||||
}
|
||||
|
||||
// Set extra fields
|
||||
response.ExtraFields.Latency = latency.Milliseconds()
|
||||
|
||||
// Set raw response if enabled
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
var rawResponse interface{}
|
||||
if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err == nil {
|
||||
response.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// SpeechStream is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Transcription performs an audio transcription request to the Mistral API.
|
||||
// It creates a multipart form with the audio file and sends it to Mistral's transcription endpoint.
|
||||
// Returns the transcribed text and metadata, or an error if the request fails.
|
||||
func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
// Convert Bifrost request to Mistral format
|
||||
mistralReq := ToMistralTranscriptionRequest(request)
|
||||
if mistralReq == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
|
||||
}
|
||||
|
||||
// Create multipart form body
|
||||
body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, provider.GetProviderKey())
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/audio/transcriptions"))
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType(contentType)
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
|
||||
req.SetBody(body.Bytes())
|
||||
|
||||
// Make request
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, ParseMistralError(resp)
|
||||
}
|
||||
|
||||
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
// Check for empty response
|
||||
trimmed := strings.TrimSpace(string(responseBody))
|
||||
if len(trimmed) == 0 {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: schemas.ErrProviderResponseEmpty,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
copiedResponseBody := append([]byte(nil), responseBody...)
|
||||
|
||||
// Parse Mistral's transcription response
|
||||
var mistralResponse MistralTranscriptionResponse
|
||||
if err := sonic.Unmarshal(copiedResponseBody, &mistralResponse); err != nil {
|
||||
if providerUtils.IsHTMLResponse(resp, copiedResponseBody) {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: schemas.ErrProviderResponseHTML,
|
||||
Error: errors.New(string(copiedResponseBody)),
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
||||
}
|
||||
|
||||
// Convert to Bifrost format
|
||||
response := mistralResponse.ToBifrostTranscriptionResponse()
|
||||
if response == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil)
|
||||
}
|
||||
|
||||
// Set extra fields
|
||||
response.ExtraFields.Latency = latency.Milliseconds()
|
||||
|
||||
// Set raw response if enabled
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
var rawResponse interface{}
|
||||
if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err == nil {
|
||||
response.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// TranscriptionStream performs a streaming transcription request to Mistral's API.
|
||||
// It creates a multipart form with the audio file and streams transcription events.
|
||||
// Returns a channel of BifrostStreamChunk objects containing transcription deltas.
|
||||
func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
providerName := provider.GetProviderKey()
|
||||
|
||||
// Convert Bifrost request to Mistral format
|
||||
mistralReq := ToMistralTranscriptionRequest(request)
|
||||
if mistralReq == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
|
||||
}
|
||||
mistralReq.Stream = schemas.Ptr(true)
|
||||
|
||||
// Create multipart form body with stream=true
|
||||
body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, providerName)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Prepare headers for streaming
|
||||
headers := map[string]string{
|
||||
"Content-Type": contentType,
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
}
|
||||
|
||||
if key.Value.GetValue() != "" {
|
||||
headers["Authorization"] = "Bearer " + key.Value.GetValue()
|
||||
}
|
||||
|
||||
// Create HTTP request for streaming
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
resp.StreamBody = true
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/audio/transcriptions"))
|
||||
|
||||
// Set headers
|
||||
for headerKey, value := range headers {
|
||||
req.Header.Set(headerKey, value)
|
||||
}
|
||||
|
||||
req.SetBody(body.Bytes())
|
||||
|
||||
// Make the request
|
||||
err := provider.streamingClient.Do(req, resp)
|
||||
if err != nil {
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(schemas.RequestCancelled),
|
||||
Message: schemas.ErrRequestCancelled,
|
||||
Error: err,
|
||||
},
|
||||
}
|
||||
}
|
||||
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
||||
}
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
||||
}
|
||||
|
||||
// Store provider response headers in context before status check so error responses also forward them
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
||||
|
||||
// Check for HTTP errors
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
return nil, ParseMistralError(resp)
|
||||
}
|
||||
|
||||
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
||||
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
||||
responseChan := make(chan *schemas.BifrostStreamChunk)
|
||||
close(responseChan)
|
||||
return responseChan, nil
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
||||
|
||||
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
||||
|
||||
// Start streaming in a goroutine
|
||||
go func() {
|
||||
defer func() {
|
||||
if ctx.Err() == context.Canceled {
|
||||
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
||||
}
|
||||
close(responseChan)
|
||||
}()
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
||||
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
||||
defer releaseGzip()
|
||||
|
||||
// Wrap reader with idle timeout to detect stalled streams.
|
||||
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
||||
defer stopIdleTimeout()
|
||||
|
||||
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
||||
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
||||
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
|
||||
defer stopCancellation()
|
||||
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
||||
|
||||
sseReader := providerUtils.GetSSEEventReader(ctx, reader)
|
||||
chunkIndex := -1
|
||||
|
||||
startTime := time.Now()
|
||||
lastChunkTime := startTime
|
||||
|
||||
for {
|
||||
// If context was cancelled/timed out, let defer handle it
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
eventType, eventDataBytes, readErr := sseReader.ReadEvent()
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
// If context was cancelled/timed out, let defer handle it
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
provider.logger.Warn("Error reading stream: %v", readErr)
|
||||
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger, postHookSpanFinalizer)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
currentEvent := eventType
|
||||
currentData := string(eventDataBytes)
|
||||
if currentEvent == "" || currentData == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkIndex++
|
||||
provider.processTranscriptionStreamEvent(ctx, postHookRunner, currentEvent, currentData, request.Model, providerName, chunkIndex, startTime, &lastChunkTime, responseChan, postHookSpanFinalizer)
|
||||
// Break on terminal stream indicator (covers both done events and error events
|
||||
// that processTranscriptionStreamEvent signals via context).
|
||||
if ended, _ := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator).(bool); ended {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return responseChan, nil
|
||||
}
|
||||
|
||||
// processTranscriptionStreamEvent processes a single SSE event and sends it to the response channel.
|
||||
func (provider *MistralProvider) processTranscriptionStreamEvent(
|
||||
ctx *schemas.BifrostContext,
|
||||
postHookRunner schemas.PostHookRunner,
|
||||
eventType string,
|
||||
jsonData string,
|
||||
model string,
|
||||
providerName schemas.ModelProvider,
|
||||
chunkIndex int,
|
||||
startTime time.Time,
|
||||
lastChunkTime *time.Time,
|
||||
responseChan chan *schemas.BifrostStreamChunk,
|
||||
postHookSpanFinalizer func(context.Context),
|
||||
) {
|
||||
// Skip empty data
|
||||
if strings.TrimSpace(jsonData) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Quick check for error field (allocation-free using sonic.GetFromString)
|
||||
if errorNode, _ := sonic.GetFromString(jsonData, "error"); errorNode.Exists() {
|
||||
// Only unmarshal when we know there's an error
|
||||
var bifrostErr schemas.BifrostError
|
||||
if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
|
||||
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger, postHookSpanFinalizer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the event data
|
||||
var eventData MistralTranscriptionStreamData
|
||||
if err := sonic.UnmarshalString(jsonData, &eventData); err != nil {
|
||||
provider.logger.Warn("Failed to parse stream event data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create the stream event
|
||||
streamEvent := &MistralTranscriptionStreamEvent{
|
||||
Event: eventType,
|
||||
Data: &eventData,
|
||||
}
|
||||
|
||||
// Convert to Bifrost format
|
||||
response := streamEvent.ToBifrostTranscriptionStreamResponse()
|
||||
if response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Set extra fields
|
||||
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
ChunkIndex: chunkIndex,
|
||||
Latency: time.Since(*lastChunkTime).Milliseconds(),
|
||||
}
|
||||
*lastChunkTime = time.Now()
|
||||
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
response.ExtraFields.RawResponse = jsonData
|
||||
}
|
||||
|
||||
// Check for done event (handle both "transcription.done" and "transcript.text.done")
|
||||
if MistralTranscriptionStreamEventType(eventType) == MistralTranscriptionStreamEventDone || eventType == "transcript.text.done" {
|
||||
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
// Ensure response type is set to Done
|
||||
response.Type = schemas.TranscriptionStreamResponseTypeDone
|
||||
}
|
||||
|
||||
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan, postHookSpanFinalizer)
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by Mistral provider.
|
||||
func (provider *MistralProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRetrieve is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDownload is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDelete is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoList is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the Mistral provider.
|
||||
func (provider *MistralProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
func (provider *MistralProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
64
core/providers/mistral/mistral_test.go
Normal file
64
core/providers/mistral/mistral_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package mistral_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestMistral(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("MISTRAL_API_KEY")) == "" {
|
||||
t.Skip("Skipping Mistral tests because MISTRAL_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Mistral,
|
||||
ChatModel: "mistral-medium-2508",
|
||||
Fallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.Mistral, Model: "mistral-small-2503"},
|
||||
},
|
||||
VisionModel: "pixtral-12b-latest",
|
||||
EmbeddingModel: "codestral-embed",
|
||||
TranscriptionModel: "voxtral-mini-latest", // Mistral's audio transcription model
|
||||
ExternalTTSProvider: schemas.OpenAI,
|
||||
ExternalTTSModel: "gpt-4o-mini-tts",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false, // Not supported
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
ImageURL: true,
|
||||
ImageBase64: true,
|
||||
MultipleImages: true,
|
||||
FileBase64: false, // supports documents url
|
||||
FileURL: false, // bifrost limitation: native mistral api converter needed
|
||||
CompleteEnd2End: true,
|
||||
Embedding: true,
|
||||
Transcription: true,
|
||||
TranscriptionStream: true,
|
||||
ListModels: true,
|
||||
Reasoning: false, // Not supported right now because we are not using native mistral converters
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("MistralTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
55
core/providers/mistral/models.go
Normal file
55
core/providers/mistral/models.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Data)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: schemas.Mistral,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Data {
|
||||
for _, result := range pipeline.FilterModel(model.ID) {
|
||||
entry := schemas.Model{
|
||||
ID: string(schemas.Mistral) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.Name),
|
||||
Description: schemas.Ptr(model.Description),
|
||||
Created: schemas.Ptr(model.Created),
|
||||
ContextLength: schemas.Ptr(int(model.MaxContextLength)),
|
||||
OwnedBy: schemas.Ptr(model.OwnedBy),
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
104
core/providers/mistral/ocr.go
Normal file
104
core/providers/mistral/ocr.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToMistralOCRRequest converts a Bifrost OCR request to a Mistral OCR request.
|
||||
func ToMistralOCRRequest(req *schemas.BifrostOCRRequest) *MistralOCRRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mistralReq := &MistralOCRRequest{
|
||||
Model: req.Model,
|
||||
Document: MistralOCRDocument{
|
||||
Type: string(req.Document.Type),
|
||||
},
|
||||
}
|
||||
|
||||
if req.ID != nil {
|
||||
mistralReq.ID = *req.ID
|
||||
}
|
||||
|
||||
switch req.Document.Type {
|
||||
case schemas.OCRDocumentTypeDocumentURL:
|
||||
if req.Document.DocumentURL != nil {
|
||||
mistralReq.Document.DocumentURL = *req.Document.DocumentURL
|
||||
}
|
||||
case schemas.OCRDocumentTypeImageURL:
|
||||
if req.Document.ImageURL != nil {
|
||||
mistralReq.Document.ImageURL = *req.Document.ImageURL
|
||||
}
|
||||
}
|
||||
|
||||
if req.Params != nil {
|
||||
mistralReq.IncludeImageBase64 = req.Params.IncludeImageBase64
|
||||
mistralReq.Pages = req.Params.Pages
|
||||
mistralReq.ImageLimit = req.Params.ImageLimit
|
||||
mistralReq.ImageMinSize = req.Params.ImageMinSize
|
||||
mistralReq.TableFormat = req.Params.TableFormat
|
||||
mistralReq.ExtractHeader = req.Params.ExtractHeader
|
||||
mistralReq.ExtractFooter = req.Params.ExtractFooter
|
||||
mistralReq.BBoxAnnotationFormat = req.Params.BBoxAnnotationFormat
|
||||
mistralReq.DocumentAnnotationFormat = req.Params.DocumentAnnotationFormat
|
||||
mistralReq.DocumentAnnotationPrompt = req.Params.DocumentAnnotationPrompt
|
||||
mistralReq.ExtraParams = req.Params.ExtraParams
|
||||
}
|
||||
|
||||
return mistralReq
|
||||
}
|
||||
|
||||
// ToBifrostOCRResponse converts a Mistral OCR response to a Bifrost OCR response.
|
||||
func (r *MistralOCRResponse) ToBifrostOCRResponse() *schemas.BifrostOCRResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp := &schemas.BifrostOCRResponse{
|
||||
Model: r.Model,
|
||||
DocumentAnnotation: r.DocumentAnnotation,
|
||||
}
|
||||
|
||||
// Convert pages
|
||||
if len(r.Pages) > 0 {
|
||||
resp.Pages = make([]schemas.OCRPage, len(r.Pages))
|
||||
for i, p := range r.Pages {
|
||||
page := schemas.OCRPage{
|
||||
Index: p.Index,
|
||||
Markdown: p.Markdown,
|
||||
}
|
||||
if len(p.Images) > 0 {
|
||||
page.Images = make([]schemas.OCRPageImage, len(p.Images))
|
||||
for j, img := range p.Images {
|
||||
page.Images[j] = schemas.OCRPageImage{
|
||||
ID: img.ID,
|
||||
TopLeftX: img.TopLeftX,
|
||||
TopLeftY: img.TopLeftY,
|
||||
BottomRightX: img.BottomRightX,
|
||||
BottomRightY: img.BottomRightY,
|
||||
ImageBase64: img.ImageBase64,
|
||||
}
|
||||
}
|
||||
}
|
||||
if p.Dimensions != nil {
|
||||
page.Dimensions = &schemas.OCRPageDimensions{
|
||||
DPI: p.Dimensions.DPI,
|
||||
Height: p.Dimensions.Height,
|
||||
Width: p.Dimensions.Width,
|
||||
}
|
||||
}
|
||||
resp.Pages[i] = page
|
||||
}
|
||||
}
|
||||
|
||||
// Convert usage info
|
||||
if r.UsageInfo != nil {
|
||||
resp.UsageInfo = &schemas.OCRUsageInfo{
|
||||
PagesProcessed: r.UsageInfo.PagesProcessed,
|
||||
DocSizeBytes: r.UsageInfo.DocSizeBytes,
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
755
core/providers/mistral/ocr_test.go
Normal file
755
core/providers/mistral/ocr_test.go
Normal file
@@ -0,0 +1,755 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestToMistralOCRRequest tests conversion from Bifrost OCR request to Mistral OCR request.
|
||||
func TestToMistralOCRRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *schemas.BifrostOCRRequest
|
||||
validate func(t *testing.T, result *MistralOCRRequest)
|
||||
}{
|
||||
{
|
||||
name: "nil request returns nil",
|
||||
input: nil,
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
assert.Nil(t, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic document_url request",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
assert.Equal(t, "document_url", result.Document.Type)
|
||||
assert.Equal(t, "https://example.com/doc.pdf", result.Document.DocumentURL)
|
||||
assert.Empty(t, result.Document.ImageURL)
|
||||
assert.Empty(t, result.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic image_url request",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeImageURL,
|
||||
ImageURL: schemas.Ptr("https://example.com/image.png"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
assert.Equal(t, "image_url", result.Document.Type)
|
||||
assert.Equal(t, "https://example.com/image.png", result.Document.ImageURL)
|
||||
assert.Empty(t, result.Document.DocumentURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request with ID",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
ID: schemas.Ptr("req-123"),
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "req-123", result.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request with all parameters",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
Params: &schemas.OCRParameters{
|
||||
IncludeImageBase64: schemas.Ptr(true),
|
||||
Pages: []int{0, 1, 2},
|
||||
ImageLimit: schemas.Ptr(10),
|
||||
ImageMinSize: schemas.Ptr(100),
|
||||
TableFormat: schemas.Ptr("html"),
|
||||
ExtractHeader: schemas.Ptr(true),
|
||||
ExtractFooter: schemas.Ptr(false),
|
||||
BBoxAnnotationFormat: schemas.Ptr("json"),
|
||||
DocumentAnnotationFormat: schemas.Ptr("markdown"),
|
||||
DocumentAnnotationPrompt: schemas.Ptr("Summarize this document"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
assert.Equal(t, "document_url", result.Document.Type)
|
||||
assert.Equal(t, "https://example.com/doc.pdf", result.Document.DocumentURL)
|
||||
|
||||
require.NotNil(t, result.IncludeImageBase64)
|
||||
assert.True(t, *result.IncludeImageBase64)
|
||||
assert.Equal(t, []int{0, 1, 2}, result.Pages)
|
||||
require.NotNil(t, result.ImageLimit)
|
||||
assert.Equal(t, 10, *result.ImageLimit)
|
||||
require.NotNil(t, result.ImageMinSize)
|
||||
assert.Equal(t, 100, *result.ImageMinSize)
|
||||
require.NotNil(t, result.TableFormat)
|
||||
assert.Equal(t, "html", *result.TableFormat)
|
||||
require.NotNil(t, result.ExtractHeader)
|
||||
assert.True(t, *result.ExtractHeader)
|
||||
require.NotNil(t, result.ExtractFooter)
|
||||
assert.False(t, *result.ExtractFooter)
|
||||
require.NotNil(t, result.BBoxAnnotationFormat)
|
||||
assert.Equal(t, "json", *result.BBoxAnnotationFormat)
|
||||
require.NotNil(t, result.DocumentAnnotationFormat)
|
||||
assert.Equal(t, "markdown", *result.DocumentAnnotationFormat)
|
||||
require.NotNil(t, result.DocumentAnnotationPrompt)
|
||||
assert.Equal(t, "Summarize this document", *result.DocumentAnnotationPrompt)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request with nil params",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
Params: nil,
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Nil(t, result.IncludeImageBase64)
|
||||
assert.Nil(t, result.Pages)
|
||||
assert.Nil(t, result.ImageLimit)
|
||||
assert.Nil(t, result.ImageMinSize)
|
||||
assert.Nil(t, result.TableFormat)
|
||||
assert.Nil(t, result.ExtractHeader)
|
||||
assert.Nil(t, result.ExtractFooter)
|
||||
assert.Nil(t, result.BBoxAnnotationFormat)
|
||||
assert.Nil(t, result.DocumentAnnotationFormat)
|
||||
assert.Nil(t, result.DocumentAnnotationPrompt)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := ToMistralOCRRequest(tt.input)
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToBifrostOCRResponse tests conversion from Mistral OCR response to Bifrost OCR response.
|
||||
func TestToBifrostOCRResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *MistralOCRResponse
|
||||
validate func(t *testing.T, result *schemas.BifrostOCRResponse)
|
||||
}{
|
||||
{
|
||||
name: "nil response returns nil",
|
||||
input: nil,
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
assert.Nil(t, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic response with single page",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "# Hello World\n\nThis is a test document.",
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
require.Len(t, result.Pages, 1)
|
||||
assert.Equal(t, 0, result.Pages[0].Index)
|
||||
assert.Equal(t, "# Hello World\n\nThis is a test document.", result.Pages[0].Markdown)
|
||||
assert.Nil(t, result.Pages[0].Images)
|
||||
assert.Nil(t, result.Pages[0].Dimensions)
|
||||
assert.Nil(t, result.UsageInfo)
|
||||
assert.Nil(t, result.DocumentAnnotation)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with images",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "Page with image",
|
||||
Images: []MistralOCRPageImage{
|
||||
{
|
||||
ID: "img-1",
|
||||
TopLeftX: 10.5,
|
||||
TopLeftY: 20.3,
|
||||
BottomRightX: 100.0,
|
||||
BottomRightY: 200.0,
|
||||
ImageBase64: schemas.Ptr("base64encodeddata"),
|
||||
},
|
||||
{
|
||||
ID: "img-2",
|
||||
TopLeftX: 50.0,
|
||||
TopLeftY: 60.0,
|
||||
BottomRightX: 150.0,
|
||||
BottomRightY: 250.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.Pages, 1)
|
||||
require.Len(t, result.Pages[0].Images, 2)
|
||||
|
||||
img1 := result.Pages[0].Images[0]
|
||||
assert.Equal(t, "img-1", img1.ID)
|
||||
assert.Equal(t, 10.5, img1.TopLeftX)
|
||||
assert.Equal(t, 20.3, img1.TopLeftY)
|
||||
assert.Equal(t, 100.0, img1.BottomRightX)
|
||||
assert.Equal(t, 200.0, img1.BottomRightY)
|
||||
require.NotNil(t, img1.ImageBase64)
|
||||
assert.Equal(t, "base64encodeddata", *img1.ImageBase64)
|
||||
|
||||
img2 := result.Pages[0].Images[1]
|
||||
assert.Equal(t, "img-2", img2.ID)
|
||||
assert.Nil(t, img2.ImageBase64)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with dimensions",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "Page with dimensions",
|
||||
Dimensions: &MistralOCRPageDimensions{
|
||||
DPI: 300,
|
||||
Height: 2200,
|
||||
Width: 1700,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.Pages, 1)
|
||||
require.NotNil(t, result.Pages[0].Dimensions)
|
||||
assert.Equal(t, 300, result.Pages[0].Dimensions.DPI)
|
||||
assert.Equal(t, 2200, result.Pages[0].Dimensions.Height)
|
||||
assert.Equal(t, 1700, result.Pages[0].Dimensions.Width)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with usage info",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{Index: 0, Markdown: "Page 1"},
|
||||
{Index: 1, Markdown: "Page 2"},
|
||||
},
|
||||
UsageInfo: &MistralOCRUsageInfo{
|
||||
PagesProcessed: 2,
|
||||
DocSizeBytes: 1024000,
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.Pages, 2)
|
||||
require.NotNil(t, result.UsageInfo)
|
||||
assert.Equal(t, 2, result.UsageInfo.PagesProcessed)
|
||||
assert.Equal(t, 1024000, result.UsageInfo.DocSizeBytes)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with document annotation",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{Index: 0, Markdown: "Page content"},
|
||||
},
|
||||
DocumentAnnotation: schemas.Ptr("This is a legal contract."),
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.DocumentAnnotation)
|
||||
assert.Equal(t, "This is a legal contract.", *result.DocumentAnnotation)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with empty pages",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
assert.Empty(t, result.Pages)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full response with all fields",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "# Title\n\nParagraph with **bold** text.",
|
||||
Images: []MistralOCRPageImage{
|
||||
{
|
||||
ID: "img-0-1",
|
||||
TopLeftX: 0,
|
||||
TopLeftY: 0,
|
||||
BottomRightX: 500,
|
||||
BottomRightY: 300,
|
||||
ImageBase64: schemas.Ptr("aW1hZ2VkYXRh"),
|
||||
},
|
||||
},
|
||||
Dimensions: &MistralOCRPageDimensions{
|
||||
DPI: 150,
|
||||
Height: 1100,
|
||||
Width: 850,
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageInfo: &MistralOCRUsageInfo{
|
||||
PagesProcessed: 1,
|
||||
DocSizeBytes: 512000,
|
||||
},
|
||||
DocumentAnnotation: schemas.Ptr("A technical report."),
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
require.Len(t, result.Pages, 1)
|
||||
|
||||
page := result.Pages[0]
|
||||
assert.Equal(t, 0, page.Index)
|
||||
assert.Contains(t, page.Markdown, "# Title")
|
||||
require.Len(t, page.Images, 1)
|
||||
assert.Equal(t, "img-0-1", page.Images[0].ID)
|
||||
require.NotNil(t, page.Images[0].ImageBase64)
|
||||
assert.Equal(t, "aW1hZ2VkYXRh", *page.Images[0].ImageBase64)
|
||||
require.NotNil(t, page.Dimensions)
|
||||
assert.Equal(t, 150, page.Dimensions.DPI)
|
||||
|
||||
require.NotNil(t, result.UsageInfo)
|
||||
assert.Equal(t, 1, result.UsageInfo.PagesProcessed)
|
||||
assert.Equal(t, 512000, result.UsageInfo.DocSizeBytes)
|
||||
|
||||
require.NotNil(t, result.DocumentAnnotation)
|
||||
assert.Equal(t, "A technical report.", *result.DocumentAnnotation)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := tt.input.ToBifrostOCRResponse()
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCRWithMockServer tests the OCR method with a mock HTTP server.
|
||||
func TestOCRWithMockServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request *schemas.BifrostOCRRequest
|
||||
statusCode int
|
||||
responseBody interface{}
|
||||
expectError bool
|
||||
errorContains string
|
||||
validateError func(t *testing.T, err *schemas.BifrostError)
|
||||
validateResult func(t *testing.T, resp *schemas.BifrostOCRResponse)
|
||||
}{
|
||||
{
|
||||
name: "successful OCR with document_url",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "# Test Document\n\nThis is page 1.",
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
Markdown: "## Section 2\n\nThis is page 2.",
|
||||
},
|
||||
},
|
||||
UsageInfo: &MistralOCRUsageInfo{
|
||||
PagesProcessed: 2,
|
||||
DocSizeBytes: 2048,
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
validateResult: func(t *testing.T, resp *schemas.BifrostOCRResponse) {
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.Len(t, resp.Pages, 2)
|
||||
assert.Equal(t, 0, resp.Pages[0].Index)
|
||||
assert.Contains(t, resp.Pages[0].Markdown, "Test Document")
|
||||
assert.Equal(t, 1, resp.Pages[1].Index)
|
||||
require.NotNil(t, resp.UsageInfo)
|
||||
assert.Equal(t, 2, resp.UsageInfo.PagesProcessed)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful OCR with image_url",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeImageURL,
|
||||
ImageURL: schemas.Ptr("https://example.com/image.png"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "Text extracted from image",
|
||||
Images: []MistralOCRPageImage{
|
||||
{
|
||||
ID: "img-1",
|
||||
TopLeftX: 0,
|
||||
TopLeftY: 0,
|
||||
BottomRightX: 100,
|
||||
BottomRightY: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
validateResult: func(t *testing.T, resp *schemas.BifrostOCRResponse) {
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.Len(t, resp.Pages, 1)
|
||||
require.Len(t, resp.Pages[0].Images, 1)
|
||||
assert.Equal(t, "img-1", resp.Pages[0].Images[0].ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "server error 500",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusInternalServerError,
|
||||
responseBody: map[string]interface{}{
|
||||
"message": "Internal server error",
|
||||
"type": "server_error",
|
||||
"code": "internal_error",
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "Internal server error",
|
||||
validateError: func(t *testing.T, err *schemas.BifrostError) {
|
||||
require.NotNil(t, err)
|
||||
require.NotNil(t, err.Error)
|
||||
require.NotNil(t, err.StatusCode)
|
||||
assert.Equal(t, http.StatusInternalServerError, *err.StatusCode)
|
||||
require.NotNil(t, err.Error.Type)
|
||||
assert.Equal(t, "server_error", *err.Error.Type)
|
||||
require.NotNil(t, err.Error.Code)
|
||||
assert.Equal(t, "internal_error", *err.Error.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthorized 401",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusUnauthorized,
|
||||
responseBody: map[string]interface{}{
|
||||
"message": "Unauthorized",
|
||||
"type": "authentication_error",
|
||||
"code": "invalid_api_key",
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "Unauthorized",
|
||||
validateError: func(t *testing.T, err *schemas.BifrostError) {
|
||||
require.NotNil(t, err)
|
||||
require.NotNil(t, err.Error)
|
||||
require.NotNil(t, err.StatusCode)
|
||||
assert.Equal(t, http.StatusUnauthorized, *err.StatusCode)
|
||||
require.NotNil(t, err.Error.Type)
|
||||
assert.Equal(t, "authentication_error", *err.Error.Type)
|
||||
require.NotNil(t, err.Error.Code)
|
||||
assert.Equal(t, "invalid_api_key", *err.Error.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty response body",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: nil, // will send empty body
|
||||
expectError: true,
|
||||
errorContains: "",
|
||||
},
|
||||
{
|
||||
name: "HTML error response",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: "html_error", // sentinel to trigger HTML response
|
||||
expectError: true,
|
||||
errorContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "/v1/ocr", r.URL.Path)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
assert.Contains(t, authHeader, "Bearer")
|
||||
|
||||
switch body := tt.responseBody.(type) {
|
||||
case nil:
|
||||
// Send empty body
|
||||
case string:
|
||||
if body == "html_error" {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
}
|
||||
default:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
w.WriteHeader(tt.statusCode)
|
||||
|
||||
switch body := tt.responseBody.(type) {
|
||||
case nil:
|
||||
// Send empty body
|
||||
case string:
|
||||
if body == "html_error" {
|
||||
w.Write([]byte("<html><body>502 Bad Gateway</body></html>"))
|
||||
}
|
||||
default:
|
||||
responseJSON, err := sonic.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal response: %v", err)
|
||||
}
|
||||
w.Write(responseJSON)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: server.URL,
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, tt.request)
|
||||
|
||||
if tt.expectError {
|
||||
require.NotNil(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error.Message, tt.errorContains)
|
||||
}
|
||||
if tt.validateError != nil {
|
||||
tt.validateError(t, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, resp)
|
||||
tt.validateResult(t, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCRNilInput tests handling of nil OCR request.
|
||||
func TestOCRNilInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: "https://api.mistral.ai",
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
|
||||
resp, err := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, nil)
|
||||
|
||||
require.NotNil(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error.Message, "ocr request input is not provided")
|
||||
}
|
||||
|
||||
// TestOCRRequestValidation tests that the mock server receives correctly serialized request bodies.
|
||||
func TestOCRRequestValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request body to validate it was serialized correctly
|
||||
var mistralReq MistralOCRRequest
|
||||
err := sonic.ConfigDefault.NewDecoder(r.Body).Decode(&mistralReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "mistral-ocr-latest", mistralReq.Model)
|
||||
assert.Equal(t, "document_url", mistralReq.Document.Type)
|
||||
assert.Equal(t, "https://example.com/doc.pdf", mistralReq.Document.DocumentURL)
|
||||
assert.NotNil(t, mistralReq.IncludeImageBase64)
|
||||
assert.True(t, *mistralReq.IncludeImageBase64)
|
||||
assert.Equal(t, []int{0, 1}, mistralReq.Pages)
|
||||
|
||||
// Return a valid response
|
||||
resp := MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{Index: 0, Markdown: "Page 1"},
|
||||
{Index: 1, Markdown: "Page 2"},
|
||||
},
|
||||
}
|
||||
responseJSON, _ := sonic.Marshal(resp)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(responseJSON)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: server.URL,
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
request := &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
Params: &schemas.OCRParameters{
|
||||
IncludeImageBase64: schemas.Ptr(true),
|
||||
Pages: []int{0, 1},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.Len(t, resp.Pages, 2)
|
||||
}
|
||||
|
||||
// TestMistralOCRIntegration tests the OCR endpoint with the real Mistral API.
|
||||
// This test requires MISTRAL_API_KEY environment variable to be set.
|
||||
// Run with: MISTRAL_API_KEY=xxx go test -v -run TestMistralOCRIntegration
|
||||
func TestMistralOCRIntegration(t *testing.T) {
|
||||
apiKey := os.Getenv("MISTRAL_API_KEY")
|
||||
if apiKey == "" {
|
||||
t.Skip("Skipping integration test: MISTRAL_API_KEY not set")
|
||||
}
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: "https://api.mistral.ai",
|
||||
DefaultRequestTimeoutInSeconds: 60,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
request := &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://arxiv.org/pdf/2201.04234"),
|
||||
},
|
||||
Params: &schemas.OCRParameters{
|
||||
Pages: []int{0},
|
||||
},
|
||||
}
|
||||
|
||||
resp, bifrostErr := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request)
|
||||
|
||||
require.Nil(t, bifrostErr, "OCR request failed: %v", bifrostErr)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.NotEmpty(t, resp.Pages, "Expected at least one page")
|
||||
assert.Equal(t, 0, resp.Pages[0].Index)
|
||||
assert.NotEmpty(t, resp.Pages[0].Markdown, "Expected non-empty markdown for page 0")
|
||||
assert.Greater(t, resp.ExtraFields.Latency, int64(0))
|
||||
}
|
||||
219
core/providers/mistral/transcription.go
Normal file
219
core/providers/mistral/transcription.go
Normal file
@@ -0,0 +1,219 @@
|
||||
// Package mistral implements transcription support for Mistral's audio API.
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"strconv"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToMistralTranscriptionRequest converts a Bifrost transcription request to Mistral format.
|
||||
func ToMistralTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *MistralTranscriptionRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || len(bifrostReq.Input.File) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &MistralTranscriptionRequest{
|
||||
Model: bifrostReq.Model,
|
||||
File: bifrostReq.Input.File,
|
||||
Filename: bifrostReq.Input.Filename,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.Language = bifrostReq.Params.Language
|
||||
req.Prompt = bifrostReq.Params.Prompt
|
||||
req.ResponseFormat = bifrostReq.Params.ResponseFormat
|
||||
|
||||
// Handle extra params for Mistral-specific fields
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if temp, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["temperature"]); ok {
|
||||
req.Temperature = temp
|
||||
}
|
||||
if granularities, ok := bifrostReq.Params.ExtraParams["timestamp_granularities"].([]string); ok {
|
||||
req.TimestampGranularities = granularities
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
// ToBifrostTranscriptionResponse converts a Mistral transcription response to Bifrost format.
|
||||
func (r *MistralTranscriptionResponse) ToBifrostTranscriptionResponse() *schemas.BifrostTranscriptionResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
response := &schemas.BifrostTranscriptionResponse{
|
||||
Text: r.Text,
|
||||
Duration: r.Duration,
|
||||
Language: r.Language,
|
||||
Task: schemas.Ptr("transcribe"),
|
||||
}
|
||||
|
||||
// Convert segments
|
||||
if len(r.Segments) > 0 {
|
||||
response.Segments = make([]schemas.TranscriptionSegment, len(r.Segments))
|
||||
for i, seg := range r.Segments {
|
||||
response.Segments[i] = schemas.TranscriptionSegment{
|
||||
ID: seg.ID,
|
||||
Seek: seg.Seek,
|
||||
Start: seg.Start,
|
||||
End: seg.End,
|
||||
Text: seg.Text,
|
||||
Tokens: seg.Tokens,
|
||||
Temperature: seg.Temperature,
|
||||
AvgLogProb: seg.AvgLogProb,
|
||||
CompressionRatio: seg.CompressionRatio,
|
||||
NoSpeechProb: seg.NoSpeechProb,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert words
|
||||
if len(r.Words) > 0 {
|
||||
response.Words = make([]schemas.TranscriptionWord, len(r.Words))
|
||||
for i, word := range r.Words {
|
||||
response.Words[i] = schemas.TranscriptionWord{
|
||||
Word: word.Word,
|
||||
Start: word.Start,
|
||||
End: word.End,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// createMistralTranscriptionMultipartBody creates the multipart form body for a transcription request.
|
||||
func createMistralTranscriptionMultipartBody(req *MistralTranscriptionRequest, providerName schemas.ModelProvider) (*bytes.Buffer, string, *schemas.BifrostError) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
if err := parseTranscriptionFormDataBodyFromRequest(writer, req, providerName); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &body, writer.FormDataContentType(), nil
|
||||
}
|
||||
|
||||
// parseTranscriptionFormDataBodyFromRequest writes the transcription request to a multipart form.
|
||||
func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, req *MistralTranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
|
||||
// Add model field (required) before the file so upstreams can route without buffering audio bytes.
|
||||
if err := writer.WriteField("model", req.Model); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write model field", err)
|
||||
}
|
||||
|
||||
// Add stream field if streaming
|
||||
if req.Stream != nil && *req.Stream {
|
||||
if err := writer.WriteField("stream", "true"); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write stream field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add optional fields
|
||||
if req.Language != nil {
|
||||
if err := writer.WriteField("language", *req.Language); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write language field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.Prompt != nil {
|
||||
if err := writer.WriteField("prompt", *req.Prompt); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write prompt field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.ResponseFormat != nil {
|
||||
if err := writer.WriteField("response_format", *req.ResponseFormat); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write response_format field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.Temperature != nil {
|
||||
if err := writer.WriteField("temperature", formatFloat64(*req.Temperature)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write temperature field", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, granularity := range req.TimestampGranularities {
|
||||
if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add file field last - Mistral uses "file" as the form field name.
|
||||
filename := req.Filename
|
||||
if filename == "" {
|
||||
filename = providerUtils.AudioFilenameFromBytes(req.File)
|
||||
}
|
||||
fileWriter, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to create form file", err)
|
||||
}
|
||||
if _, err := fileWriter.Write(req.File); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write file data", err)
|
||||
}
|
||||
|
||||
// Close the multipart writer to finalize the form
|
||||
if err := writer.Close(); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatFloat64 renders f as a plain decimal string (never exponent
// notation) using the smallest number of digits that still identifies the
// float64 value uniquely. It is used to serialize numeric form fields.
func formatFloat64(f float64) string {
	const (
		decimalFormat     = 'f' // fixed-point, no exponent
		shortestPrecision = -1  // minimal digits that round-trip
		float64Bits       = 64
	)
	return strconv.FormatFloat(f, decimalFormat, shortestPrecision, float64Bits)
}
|
||||
|
||||
// ToBifrostTranscriptionStreamResponse converts a Mistral streaming event to Bifrost format.
|
||||
func (e *MistralTranscriptionStreamEvent) ToBifrostTranscriptionStreamResponse() *schemas.BifrostTranscriptionStreamResponse {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
response := &schemas.BifrostTranscriptionStreamResponse{}
|
||||
|
||||
switch MistralTranscriptionStreamEventType(e.Event) {
|
||||
case MistralTranscriptionStreamEventTextDelta:
|
||||
response.Type = schemas.TranscriptionStreamResponseTypeDelta
|
||||
if e.Data != nil {
|
||||
response.Delta = &e.Data.Text
|
||||
response.Text = e.Data.Text
|
||||
}
|
||||
|
||||
case MistralTranscriptionStreamEventLanguage:
|
||||
response.Type = schemas.TranscriptionStreamResponseTypeDelta
|
||||
if e.Data != nil {
|
||||
response.Text = "" // Language event doesn't have text content
|
||||
}
|
||||
|
||||
case MistralTranscriptionStreamEventSegment:
|
||||
response.Type = schemas.TranscriptionStreamResponseTypeDelta
|
||||
if e.Data != nil && e.Data.Segment != nil {
|
||||
response.Text = e.Data.Segment.Text
|
||||
response.Delta = &e.Data.Segment.Text
|
||||
}
|
||||
|
||||
case MistralTranscriptionStreamEventDone:
|
||||
response.Type = schemas.TranscriptionStreamResponseTypeDone
|
||||
if e.Data != nil && e.Data.Usage != nil {
|
||||
totalTokens := e.Data.Usage.TotalTokens
|
||||
inputTokens := e.Data.Usage.PromptTokens
|
||||
outputTokens := e.Data.Usage.CompletionTokens
|
||||
response.Usage = &schemas.TranscriptionUsage{
|
||||
Type: "tokens",
|
||||
TotalTokens: &totalTokens,
|
||||
InputTokens: &inputTokens,
|
||||
OutputTokens: &outputTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
1671
core/providers/mistral/transcription_test.go
Normal file
1671
core/providers/mistral/transcription_test.go
Normal file
File diff suppressed because it is too large
Load Diff
206
core/providers/mistral/types.go
Normal file
206
core/providers/mistral/types.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package mistral
|
||||
|
||||
// MistralModel represents a single model in the Mistral Models API response
|
||||
type MistralModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Capabilities Capabilities `json:"capabilities"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
MaxContextLength int `json:"max_context_length"`
|
||||
Aliases []string `json:"aliases"`
|
||||
Deprecation *string `json:"deprecation,omitempty"`
|
||||
DeprecationReplacementModel *string `json:"deprecation_replacement_model,omitempty"`
|
||||
DefaultModelTemperature float64 `json:"default_model_temperature"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// Capabilities lists the features a Mistral model supports, one boolean
// flag per feature. Every flag is always present in the JSON encoding.
type Capabilities struct {
	// Completion modes.
	CompletionChat bool `json:"completion_chat"`
	CompletionFim  bool `json:"completion_fim"`

	// Additional features.
	FunctionCalling bool `json:"function_calling"`
	FineTuning      bool `json:"fine_tuning"`
	Vision          bool `json:"vision"`
	Classification  bool `json:"classification"`
}
|
||||
|
||||
// MistralListModelsResponse is the root response object from the Mistral Models API
|
||||
type MistralListModelsResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []MistralModel `json:"data"`
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Transcription Types
|
||||
// ============================================================================
|
||||
|
||||
// MistralTranscriptionRequest holds the parameters of a Mistral audio
// transcription request.
// Based on: https://docs.mistral.ai/capabilities/audio_transcription
type MistralTranscriptionRequest struct {
	// Model is required, e.g. "mistral-audio-transcribe".
	Model string `json:"model"`
	// File is required and holds the binary audio data.
	File []byte `json:"file"`
	// Filename is the original filename, used to preserve the file format
	// extension.
	Filename string `json:"filename"`

	// Language is an optional ISO 639-1 language code.
	Language *string `json:"language,omitempty"`
	// Prompt is an optional context hint for the transcription.
	Prompt *string `json:"prompt,omitempty"`
	// ResponseFormat is optional: "json", "text", "srt", "verbose_json" or "vtt".
	ResponseFormat *string `json:"response_format,omitempty"`
	// Temperature is the optional sampling temperature (0 to 1).
	Temperature *float64 `json:"temperature,omitempty"`
	// Stream optionally enables streaming mode.
	Stream *bool `json:"stream,omitempty"`
	// TimestampGranularities is optional: "word" or "segment".
	TimestampGranularities []string `json:"timestamp_granularities,omitempty"`
}
|
||||
|
||||
// MistralTranscriptionResponse represents Mistral's transcription response.
|
||||
type MistralTranscriptionResponse struct {
|
||||
Text string `json:"text"` // Transcribed text
|
||||
Duration *float64 `json:"duration,omitempty"` // Audio duration in seconds
|
||||
Language *string `json:"language,omitempty"` // Detected language
|
||||
Segments []MistralTranscriptionSegment `json:"segments,omitempty"` // Segments (verbose_json format)
|
||||
Words []MistralTranscriptionWord `json:"words,omitempty"` // Word-level timestamps
|
||||
}
|
||||
|
||||
// MistralTranscriptionSegment is one segment of a verbose_json
// transcription response.
type MistralTranscriptionSegment struct {
	// Position and content of the segment.
	ID    int     `json:"id"`
	Seek  int     `json:"seek,omitempty"`
	Start float64 `json:"start"`
	End   float64 `json:"end"`
	Text  string  `json:"text"`

	// Decoder details; each is dropped from the JSON encoding when zero/empty.
	Tokens           []int   `json:"tokens,omitempty"`
	Temperature      float64 `json:"temperature,omitempty"`
	AvgLogProb       float64 `json:"avg_logprob,omitempty"`
	CompressionRatio float64 `json:"compression_ratio,omitempty"`
	NoSpeechProb     float64 `json:"no_speech_prob,omitempty"`
}
|
||||
|
||||
// MistralTranscriptionWord carries the timing information for a single
// transcribed word: the word itself plus its start and end values.
type MistralTranscriptionWord struct {
	Word  string  `json:"word"`
	Start float64 `json:"start"`
	End   float64 `json:"end"`
}
|
||||
|
||||
// ============================================================================
|
||||
// Transcription Streaming Types
|
||||
// ============================================================================
|
||||
|
||||
// MistralTranscriptionStreamEventType names the kind of event carried by a
// streaming transcription message.
type MistralTranscriptionStreamEventType string

// Event names that appear in the "event" field of a streaming message.
const (
	// MistralTranscriptionStreamEventLanguage marks the language detection event.
	MistralTranscriptionStreamEventLanguage MistralTranscriptionStreamEventType = "transcription.language"
	// MistralTranscriptionStreamEventSegment marks a segment event.
	MistralTranscriptionStreamEventSegment MistralTranscriptionStreamEventType = "transcription.segment"
	// MistralTranscriptionStreamEventTextDelta marks a text delta event.
	MistralTranscriptionStreamEventTextDelta MistralTranscriptionStreamEventType = "transcription.text.delta"
	// MistralTranscriptionStreamEventDone marks the done event, which carries usage info.
	MistralTranscriptionStreamEventDone MistralTranscriptionStreamEventType = "transcription.done"
)
|
||||
|
||||
// MistralTranscriptionStreamEvent represents a streaming transcription event from Mistral.
|
||||
type MistralTranscriptionStreamEvent struct {
|
||||
Event string `json:"event"`
|
||||
Data *MistralTranscriptionStreamData `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// MistralTranscriptionStreamData represents the data payload for streaming events.
|
||||
type MistralTranscriptionStreamData struct {
|
||||
// For transcription.text.delta events
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// For transcription.language events
|
||||
Language string `json:"language,omitempty"`
|
||||
|
||||
// For transcription.segment events
|
||||
Segment *MistralTranscriptionStreamSegment `json:"segment,omitempty"`
|
||||
|
||||
// For transcription.done events
|
||||
Model string `json:"model,omitempty"`
|
||||
Usage *MistralTranscriptionUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// MistralTranscriptionStreamSegment is the segment carried by a streaming
// transcription event: an identifier, start and end values, and the text.
type MistralTranscriptionStreamSegment struct {
	ID    int     `json:"id"`
	Start float64 `json:"start"`
	End   float64 `json:"end"`
	Text  string  `json:"text"`
}
|
||||
|
||||
// MistralTranscriptionUsage is the usage information reported in the
// streaming done event. Counters that are zero are left out of the JSON
// encoding.
type MistralTranscriptionUsage struct {
	// Audio input counter.
	PromptAudioSeconds int `json:"prompt_audio_seconds,omitempty"`

	// Token counters.
	PromptTokens     int `json:"prompt_tokens,omitempty"`
	TotalTokens      int `json:"total_tokens,omitempty"`
	CompletionTokens int `json:"completion_tokens,omitempty"`
}
|
||||
|
||||
// ============================================================================
|
||||
// OCR Types
|
||||
// ============================================================================
|
||||
|
||||
// MistralOCRDocument is the document input of a Mistral OCR request. Type
// is always encoded; DocumentURL and ImageURL are omitted when empty.
type MistralOCRDocument struct {
	Type string `json:"type"`

	// URL fields; each is left out of the JSON encoding when empty.
	DocumentURL string `json:"document_url,omitempty"`
	ImageURL    string `json:"image_url,omitempty"`
}
|
||||
|
||||
// MistralOCRRequest represents a Mistral OCR API request.
|
||||
type MistralOCRRequest struct {
|
||||
Model string `json:"model"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Document MistralOCRDocument `json:"document"`
|
||||
IncludeImageBase64 *bool `json:"include_image_base64,omitempty"`
|
||||
Pages []int `json:"pages,omitempty"`
|
||||
ImageLimit *int `json:"image_limit,omitempty"`
|
||||
ImageMinSize *int `json:"image_min_size,omitempty"`
|
||||
TableFormat *string `json:"table_format,omitempty"`
|
||||
ExtractHeader *bool `json:"extract_header,omitempty"`
|
||||
ExtractFooter *bool `json:"extract_footer,omitempty"`
|
||||
BBoxAnnotationFormat *string `json:"bbox_annotation_format,omitempty"`
|
||||
DocumentAnnotationFormat *string `json:"document_annotation_format,omitempty"`
|
||||
DocumentAnnotationPrompt *string `json:"document_annotation_prompt,omitempty"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// MistralOCRPageImage is an image extracted from a page in Mistral's OCR
// response, identified by ID and located by two corner coordinates.
type MistralOCRPageImage struct {
	ID string `json:"id"`

	// Top-left and bottom-right corner coordinates.
	TopLeftX     float64 `json:"top_left_x"`
	TopLeftY     float64 `json:"top_left_y"`
	BottomRightX float64 `json:"bottom_right_x"`
	BottomRightY float64 `json:"bottom_right_y"`

	// ImageBase64 is optional and omitted from the JSON encoding when nil.
	ImageBase64 *string `json:"image_base64,omitempty"`
}
|
||||
|
||||
// MistralOCRPageDimensions holds the dimensions reported for a page in
// Mistral's OCR response: its DPI, height and width.
type MistralOCRPageDimensions struct {
	DPI    int `json:"dpi"`
	Height int `json:"height"`
	Width  int `json:"width"`
}
|
||||
|
||||
// MistralOCRPage represents a single page in Mistral's OCR response.
|
||||
type MistralOCRPage struct {
|
||||
Index int `json:"index"`
|
||||
Markdown string `json:"markdown"`
|
||||
Images []MistralOCRPageImage `json:"images,omitempty"`
|
||||
Dimensions *MistralOCRPageDimensions `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
// MistralOCRUsageInfo is the usage information attached to Mistral's OCR
// response: the number of pages processed and the document size in bytes.
type MistralOCRUsageInfo struct {
	PagesProcessed int `json:"pages_processed"`
	DocSizeBytes   int `json:"doc_size_bytes"`
}
|
||||
|
||||
// MistralOCRResponse represents Mistral's OCR API response.
|
||||
type MistralOCRResponse struct {
|
||||
Model string `json:"model"`
|
||||
Pages []MistralOCRPage `json:"pages"`
|
||||
UsageInfo *MistralOCRUsageInfo `json:"usage_info,omitempty"`
|
||||
DocumentAnnotation *string `json:"document_annotation,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user