package mistral import ( "context" "io" "net/http" "net/http/httptest" "testing" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) const customMistralProviderName = schemas.ModelProvider("custom-mistral") func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { t.Parallel() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseResponse(resp) resp.SetStatusCode(http.StatusBadRequest) resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`) bifrostErr := ParseMistralError(resp) require.NotNil(t, bifrostErr) require.NotNil(t, bifrostErr.Error) assert.Equal(t, "invalid request", bifrostErr.Error.Message) assert.Equal(t, schemas.Ptr("invalid_request_error"), bifrostErr.Error.Type) assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code) // Note: ExtraFields.Provider is populated by bifrost.go's dispatcher via // PopulateExtraFields, not by ParseMistralError called in isolation. } func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t *testing.T) { t.Parallel() var capturedRequest map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) require.NoError(t, err) require.NoError(t, sonic.Unmarshal(body, &capturedRequest)) w.Header().Set("Content-Type", "text/event-stream") flusher, ok := w.(http.Flusher) require.True(t, ok) _, err = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"mistral-small-latest\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"hello\"}}]}\n\n")) require.NoError(t, err) flusher.Flush() _, err = w.Write([]byte("data: [DONE]\n\n")) require.NoError(t, err) flusher.Flush() })) defer server.Close() provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{BaseURL: server.URL}, CustomProviderConfig: &schemas.CustomProviderConfig{ CustomProviderKey: string(customMistralProviderName), BaseProviderType: schemas.Mistral, }, }, &testLogger{}) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(schemas.BifrostContextKeyIsCustomProvider, true) request := &schemas.BifrostChatRequest{ Provider: customMistralProviderName, Model: "mistral-small-latest", Input: []schemas.ChatMessage{{ Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ ContentStr: schemas.Ptr("hello"), }, }}, Params: &schemas.ChatParameters{ MaxCompletionTokens: schemas.Ptr(32), ToolChoice: &schemas.ChatToolChoice{ ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{ Type: schemas.ChatToolChoiceTypeFunction, Function: &schemas.ChatToolChoiceFunction{ Name: "lookup", }, }, }, }, } postHookRunner := func(_ *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { return response, err } stream, bifrostErr := provider.ChatCompletionStream(ctx, postHookRunner, nil, schemas.Key{}, request) require.Nil(t, bifrostErr) var firstResponse *schemas.BifrostChatResponse for chunk := range stream { if chunk.BifrostError != nil { t.Fatalf("unexpected stream error: %s", chunk.BifrostError.Error.Message) } if chunk.BifrostChatResponse != nil { firstResponse = chunk.BifrostChatResponse break } } require.NotNil(t, firstResponse) // Note: ExtraFields.Provider on stream chunks is populated by bifrost.go's // dispatcher via PopulateExtraFields, not by provider streaming methods // called in isolation. require.NotNil(t, capturedRequest) assert.Equal(t, float64(32), capturedRequest["max_tokens"]) assert.NotContains(t, capturedRequest, "max_completion_tokens") assert.Equal(t, "any", capturedRequest["tool_choice"]) assert.Equal(t, "mistral-small-latest", capturedRequest["model"]) assert.Equal(t, true, capturedRequest["stream"]) assert.Equal(t, customMistralProviderName, provider.GetProviderKey()) assert.Equal(t, customMistralProviderName, request.Provider) } func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) { t.Parallel() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, err := w.Write([]byte(`{"object":"list","data":[{"object":"embedding","embedding":[0.1,0.2],"index":0}],"model":"codestral-embed","usage":{"prompt_tokens":1,"total_tokens":1}}`)) require.NoError(t, err) })) defer server.Close() provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{BaseURL: server.URL}, CustomProviderConfig: &schemas.CustomProviderConfig{ CustomProviderKey: string(customMistralProviderName), BaseProviderType: schemas.Mistral, }, }, &testLogger{}) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) request := &schemas.BifrostEmbeddingRequest{ Provider: customMistralProviderName, Model: "codestral-embed", Input: &schemas.EmbeddingInput{ Texts: []string{"hello"}, }, } response, bifrostErr := provider.Embedding(ctx, schemas.Key{}, request) require.Nil(t, bifrostErr) require.NotNil(t, response) // Note: ExtraFields.Provider and ResolvedModelUsed are populated by // bifrost.go's dispatcher via PopulateExtraFields, not by provider // methods called in isolation. }