first commit
This commit is contained in:
755
core/providers/mistral/ocr_test.go
Normal file
755
core/providers/mistral/ocr_test.go
Normal file
@@ -0,0 +1,755 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestToMistralOCRRequest tests conversion from Bifrost OCR request to Mistral OCR request.
|
||||
func TestToMistralOCRRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *schemas.BifrostOCRRequest
|
||||
validate func(t *testing.T, result *MistralOCRRequest)
|
||||
}{
|
||||
{
|
||||
name: "nil request returns nil",
|
||||
input: nil,
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
assert.Nil(t, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic document_url request",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
assert.Equal(t, "document_url", result.Document.Type)
|
||||
assert.Equal(t, "https://example.com/doc.pdf", result.Document.DocumentURL)
|
||||
assert.Empty(t, result.Document.ImageURL)
|
||||
assert.Empty(t, result.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic image_url request",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeImageURL,
|
||||
ImageURL: schemas.Ptr("https://example.com/image.png"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
assert.Equal(t, "image_url", result.Document.Type)
|
||||
assert.Equal(t, "https://example.com/image.png", result.Document.ImageURL)
|
||||
assert.Empty(t, result.Document.DocumentURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request with ID",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
ID: schemas.Ptr("req-123"),
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "req-123", result.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request with all parameters",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
Params: &schemas.OCRParameters{
|
||||
IncludeImageBase64: schemas.Ptr(true),
|
||||
Pages: []int{0, 1, 2},
|
||||
ImageLimit: schemas.Ptr(10),
|
||||
ImageMinSize: schemas.Ptr(100),
|
||||
TableFormat: schemas.Ptr("html"),
|
||||
ExtractHeader: schemas.Ptr(true),
|
||||
ExtractFooter: schemas.Ptr(false),
|
||||
BBoxAnnotationFormat: schemas.Ptr("json"),
|
||||
DocumentAnnotationFormat: schemas.Ptr("markdown"),
|
||||
DocumentAnnotationPrompt: schemas.Ptr("Summarize this document"),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
assert.Equal(t, "document_url", result.Document.Type)
|
||||
assert.Equal(t, "https://example.com/doc.pdf", result.Document.DocumentURL)
|
||||
|
||||
require.NotNil(t, result.IncludeImageBase64)
|
||||
assert.True(t, *result.IncludeImageBase64)
|
||||
assert.Equal(t, []int{0, 1, 2}, result.Pages)
|
||||
require.NotNil(t, result.ImageLimit)
|
||||
assert.Equal(t, 10, *result.ImageLimit)
|
||||
require.NotNil(t, result.ImageMinSize)
|
||||
assert.Equal(t, 100, *result.ImageMinSize)
|
||||
require.NotNil(t, result.TableFormat)
|
||||
assert.Equal(t, "html", *result.TableFormat)
|
||||
require.NotNil(t, result.ExtractHeader)
|
||||
assert.True(t, *result.ExtractHeader)
|
||||
require.NotNil(t, result.ExtractFooter)
|
||||
assert.False(t, *result.ExtractFooter)
|
||||
require.NotNil(t, result.BBoxAnnotationFormat)
|
||||
assert.Equal(t, "json", *result.BBoxAnnotationFormat)
|
||||
require.NotNil(t, result.DocumentAnnotationFormat)
|
||||
assert.Equal(t, "markdown", *result.DocumentAnnotationFormat)
|
||||
require.NotNil(t, result.DocumentAnnotationPrompt)
|
||||
assert.Equal(t, "Summarize this document", *result.DocumentAnnotationPrompt)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request with nil params",
|
||||
input: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
Params: nil,
|
||||
},
|
||||
validate: func(t *testing.T, result *MistralOCRRequest) {
|
||||
require.NotNil(t, result)
|
||||
assert.Nil(t, result.IncludeImageBase64)
|
||||
assert.Nil(t, result.Pages)
|
||||
assert.Nil(t, result.ImageLimit)
|
||||
assert.Nil(t, result.ImageMinSize)
|
||||
assert.Nil(t, result.TableFormat)
|
||||
assert.Nil(t, result.ExtractHeader)
|
||||
assert.Nil(t, result.ExtractFooter)
|
||||
assert.Nil(t, result.BBoxAnnotationFormat)
|
||||
assert.Nil(t, result.DocumentAnnotationFormat)
|
||||
assert.Nil(t, result.DocumentAnnotationPrompt)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := ToMistralOCRRequest(tt.input)
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToBifrostOCRResponse tests conversion from Mistral OCR response to Bifrost OCR response.
|
||||
func TestToBifrostOCRResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *MistralOCRResponse
|
||||
validate func(t *testing.T, result *schemas.BifrostOCRResponse)
|
||||
}{
|
||||
{
|
||||
name: "nil response returns nil",
|
||||
input: nil,
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
assert.Nil(t, result)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic response with single page",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "# Hello World\n\nThis is a test document.",
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
require.Len(t, result.Pages, 1)
|
||||
assert.Equal(t, 0, result.Pages[0].Index)
|
||||
assert.Equal(t, "# Hello World\n\nThis is a test document.", result.Pages[0].Markdown)
|
||||
assert.Nil(t, result.Pages[0].Images)
|
||||
assert.Nil(t, result.Pages[0].Dimensions)
|
||||
assert.Nil(t, result.UsageInfo)
|
||||
assert.Nil(t, result.DocumentAnnotation)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with images",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "Page with image",
|
||||
Images: []MistralOCRPageImage{
|
||||
{
|
||||
ID: "img-1",
|
||||
TopLeftX: 10.5,
|
||||
TopLeftY: 20.3,
|
||||
BottomRightX: 100.0,
|
||||
BottomRightY: 200.0,
|
||||
ImageBase64: schemas.Ptr("base64encodeddata"),
|
||||
},
|
||||
{
|
||||
ID: "img-2",
|
||||
TopLeftX: 50.0,
|
||||
TopLeftY: 60.0,
|
||||
BottomRightX: 150.0,
|
||||
BottomRightY: 250.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.Pages, 1)
|
||||
require.Len(t, result.Pages[0].Images, 2)
|
||||
|
||||
img1 := result.Pages[0].Images[0]
|
||||
assert.Equal(t, "img-1", img1.ID)
|
||||
assert.Equal(t, 10.5, img1.TopLeftX)
|
||||
assert.Equal(t, 20.3, img1.TopLeftY)
|
||||
assert.Equal(t, 100.0, img1.BottomRightX)
|
||||
assert.Equal(t, 200.0, img1.BottomRightY)
|
||||
require.NotNil(t, img1.ImageBase64)
|
||||
assert.Equal(t, "base64encodeddata", *img1.ImageBase64)
|
||||
|
||||
img2 := result.Pages[0].Images[1]
|
||||
assert.Equal(t, "img-2", img2.ID)
|
||||
assert.Nil(t, img2.ImageBase64)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with dimensions",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "Page with dimensions",
|
||||
Dimensions: &MistralOCRPageDimensions{
|
||||
DPI: 300,
|
||||
Height: 2200,
|
||||
Width: 1700,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.Pages, 1)
|
||||
require.NotNil(t, result.Pages[0].Dimensions)
|
||||
assert.Equal(t, 300, result.Pages[0].Dimensions.DPI)
|
||||
assert.Equal(t, 2200, result.Pages[0].Dimensions.Height)
|
||||
assert.Equal(t, 1700, result.Pages[0].Dimensions.Width)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with usage info",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{Index: 0, Markdown: "Page 1"},
|
||||
{Index: 1, Markdown: "Page 2"},
|
||||
},
|
||||
UsageInfo: &MistralOCRUsageInfo{
|
||||
PagesProcessed: 2,
|
||||
DocSizeBytes: 1024000,
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result.Pages, 2)
|
||||
require.NotNil(t, result.UsageInfo)
|
||||
assert.Equal(t, 2, result.UsageInfo.PagesProcessed)
|
||||
assert.Equal(t, 1024000, result.UsageInfo.DocSizeBytes)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with document annotation",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{Index: 0, Markdown: "Page content"},
|
||||
},
|
||||
DocumentAnnotation: schemas.Ptr("This is a legal contract."),
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.DocumentAnnotation)
|
||||
assert.Equal(t, "This is a legal contract.", *result.DocumentAnnotation)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with empty pages",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{},
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
assert.Empty(t, result.Pages)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full response with all fields",
|
||||
input: &MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "# Title\n\nParagraph with **bold** text.",
|
||||
Images: []MistralOCRPageImage{
|
||||
{
|
||||
ID: "img-0-1",
|
||||
TopLeftX: 0,
|
||||
TopLeftY: 0,
|
||||
BottomRightX: 500,
|
||||
BottomRightY: 300,
|
||||
ImageBase64: schemas.Ptr("aW1hZ2VkYXRh"),
|
||||
},
|
||||
},
|
||||
Dimensions: &MistralOCRPageDimensions{
|
||||
DPI: 150,
|
||||
Height: 1100,
|
||||
Width: 850,
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageInfo: &MistralOCRUsageInfo{
|
||||
PagesProcessed: 1,
|
||||
DocSizeBytes: 512000,
|
||||
},
|
||||
DocumentAnnotation: schemas.Ptr("A technical report."),
|
||||
},
|
||||
validate: func(t *testing.T, result *schemas.BifrostOCRResponse) {
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "mistral-ocr-latest", result.Model)
|
||||
require.Len(t, result.Pages, 1)
|
||||
|
||||
page := result.Pages[0]
|
||||
assert.Equal(t, 0, page.Index)
|
||||
assert.Contains(t, page.Markdown, "# Title")
|
||||
require.Len(t, page.Images, 1)
|
||||
assert.Equal(t, "img-0-1", page.Images[0].ID)
|
||||
require.NotNil(t, page.Images[0].ImageBase64)
|
||||
assert.Equal(t, "aW1hZ2VkYXRh", *page.Images[0].ImageBase64)
|
||||
require.NotNil(t, page.Dimensions)
|
||||
assert.Equal(t, 150, page.Dimensions.DPI)
|
||||
|
||||
require.NotNil(t, result.UsageInfo)
|
||||
assert.Equal(t, 1, result.UsageInfo.PagesProcessed)
|
||||
assert.Equal(t, 512000, result.UsageInfo.DocSizeBytes)
|
||||
|
||||
require.NotNil(t, result.DocumentAnnotation)
|
||||
assert.Equal(t, "A technical report.", *result.DocumentAnnotation)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := tt.input.ToBifrostOCRResponse()
|
||||
tt.validate(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCRWithMockServer tests the OCR method with a mock HTTP server.
|
||||
func TestOCRWithMockServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request *schemas.BifrostOCRRequest
|
||||
statusCode int
|
||||
responseBody interface{}
|
||||
expectError bool
|
||||
errorContains string
|
||||
validateError func(t *testing.T, err *schemas.BifrostError)
|
||||
validateResult func(t *testing.T, resp *schemas.BifrostOCRResponse)
|
||||
}{
|
||||
{
|
||||
name: "successful OCR with document_url",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "# Test Document\n\nThis is page 1.",
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
Markdown: "## Section 2\n\nThis is page 2.",
|
||||
},
|
||||
},
|
||||
UsageInfo: &MistralOCRUsageInfo{
|
||||
PagesProcessed: 2,
|
||||
DocSizeBytes: 2048,
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
validateResult: func(t *testing.T, resp *schemas.BifrostOCRResponse) {
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.Len(t, resp.Pages, 2)
|
||||
assert.Equal(t, 0, resp.Pages[0].Index)
|
||||
assert.Contains(t, resp.Pages[0].Markdown, "Test Document")
|
||||
assert.Equal(t, 1, resp.Pages[1].Index)
|
||||
require.NotNil(t, resp.UsageInfo)
|
||||
assert.Equal(t, 2, resp.UsageInfo.PagesProcessed)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful OCR with image_url",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeImageURL,
|
||||
ImageURL: schemas.Ptr("https://example.com/image.png"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{
|
||||
Index: 0,
|
||||
Markdown: "Text extracted from image",
|
||||
Images: []MistralOCRPageImage{
|
||||
{
|
||||
ID: "img-1",
|
||||
TopLeftX: 0,
|
||||
TopLeftY: 0,
|
||||
BottomRightX: 100,
|
||||
BottomRightY: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
validateResult: func(t *testing.T, resp *schemas.BifrostOCRResponse) {
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.Len(t, resp.Pages, 1)
|
||||
require.Len(t, resp.Pages[0].Images, 1)
|
||||
assert.Equal(t, "img-1", resp.Pages[0].Images[0].ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "server error 500",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusInternalServerError,
|
||||
responseBody: map[string]interface{}{
|
||||
"message": "Internal server error",
|
||||
"type": "server_error",
|
||||
"code": "internal_error",
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "Internal server error",
|
||||
validateError: func(t *testing.T, err *schemas.BifrostError) {
|
||||
require.NotNil(t, err)
|
||||
require.NotNil(t, err.Error)
|
||||
require.NotNil(t, err.StatusCode)
|
||||
assert.Equal(t, http.StatusInternalServerError, *err.StatusCode)
|
||||
require.NotNil(t, err.Error.Type)
|
||||
assert.Equal(t, "server_error", *err.Error.Type)
|
||||
require.NotNil(t, err.Error.Code)
|
||||
assert.Equal(t, "internal_error", *err.Error.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthorized 401",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusUnauthorized,
|
||||
responseBody: map[string]interface{}{
|
||||
"message": "Unauthorized",
|
||||
"type": "authentication_error",
|
||||
"code": "invalid_api_key",
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "Unauthorized",
|
||||
validateError: func(t *testing.T, err *schemas.BifrostError) {
|
||||
require.NotNil(t, err)
|
||||
require.NotNil(t, err.Error)
|
||||
require.NotNil(t, err.StatusCode)
|
||||
assert.Equal(t, http.StatusUnauthorized, *err.StatusCode)
|
||||
require.NotNil(t, err.Error.Type)
|
||||
assert.Equal(t, "authentication_error", *err.Error.Type)
|
||||
require.NotNil(t, err.Error.Code)
|
||||
assert.Equal(t, "invalid_api_key", *err.Error.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty response body",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: nil, // will send empty body
|
||||
expectError: true,
|
||||
errorContains: "",
|
||||
},
|
||||
{
|
||||
name: "HTML error response",
|
||||
request: &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: "html_error", // sentinel to trigger HTML response
|
||||
expectError: true,
|
||||
errorContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "/v1/ocr", r.URL.Path)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
assert.Contains(t, authHeader, "Bearer")
|
||||
|
||||
switch body := tt.responseBody.(type) {
|
||||
case nil:
|
||||
// Send empty body
|
||||
case string:
|
||||
if body == "html_error" {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
}
|
||||
default:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
w.WriteHeader(tt.statusCode)
|
||||
|
||||
switch body := tt.responseBody.(type) {
|
||||
case nil:
|
||||
// Send empty body
|
||||
case string:
|
||||
if body == "html_error" {
|
||||
w.Write([]byte("<html><body>502 Bad Gateway</body></html>"))
|
||||
}
|
||||
default:
|
||||
responseJSON, err := sonic.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal response: %v", err)
|
||||
}
|
||||
w.Write(responseJSON)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: server.URL,
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, tt.request)
|
||||
|
||||
if tt.expectError {
|
||||
require.NotNil(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error.Message, tt.errorContains)
|
||||
}
|
||||
if tt.validateError != nil {
|
||||
tt.validateError(t, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, resp)
|
||||
tt.validateResult(t, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCRNilInput tests handling of nil OCR request.
|
||||
func TestOCRNilInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: "https://api.mistral.ai",
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
|
||||
resp, err := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, nil)
|
||||
|
||||
require.NotNil(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error.Message, "ocr request input is not provided")
|
||||
}
|
||||
|
||||
// TestOCRRequestValidation tests that the mock server receives correctly serialized request bodies.
|
||||
func TestOCRRequestValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request body to validate it was serialized correctly
|
||||
var mistralReq MistralOCRRequest
|
||||
err := sonic.ConfigDefault.NewDecoder(r.Body).Decode(&mistralReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "mistral-ocr-latest", mistralReq.Model)
|
||||
assert.Equal(t, "document_url", mistralReq.Document.Type)
|
||||
assert.Equal(t, "https://example.com/doc.pdf", mistralReq.Document.DocumentURL)
|
||||
assert.NotNil(t, mistralReq.IncludeImageBase64)
|
||||
assert.True(t, *mistralReq.IncludeImageBase64)
|
||||
assert.Equal(t, []int{0, 1}, mistralReq.Pages)
|
||||
|
||||
// Return a valid response
|
||||
resp := MistralOCRResponse{
|
||||
Model: "mistral-ocr-latest",
|
||||
Pages: []MistralOCRPage{
|
||||
{Index: 0, Markdown: "Page 1"},
|
||||
{Index: 1, Markdown: "Page 2"},
|
||||
},
|
||||
}
|
||||
responseJSON, _ := sonic.Marshal(resp)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(responseJSON)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: server.URL,
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
request := &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://example.com/doc.pdf"),
|
||||
},
|
||||
Params: &schemas.OCRParameters{
|
||||
IncludeImageBase64: schemas.Ptr(true),
|
||||
Pages: []int{0, 1},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.Len(t, resp.Pages, 2)
|
||||
}
|
||||
|
||||
// TestMistralOCRIntegration tests the OCR endpoint with the real Mistral API.
|
||||
// This test requires MISTRAL_API_KEY environment variable to be set.
|
||||
// Run with: MISTRAL_API_KEY=xxx go test -v -run TestMistralOCRIntegration
|
||||
func TestMistralOCRIntegration(t *testing.T) {
|
||||
apiKey := os.Getenv("MISTRAL_API_KEY")
|
||||
if apiKey == "" {
|
||||
t.Skip("Skipping integration test: MISTRAL_API_KEY not set")
|
||||
}
|
||||
|
||||
provider := NewMistralProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: "https://api.mistral.ai",
|
||||
DefaultRequestTimeoutInSeconds: 60,
|
||||
},
|
||||
}, &testLogger{})
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
request := &schemas.BifrostOCRRequest{
|
||||
Model: "mistral-ocr-latest",
|
||||
Document: schemas.OCRDocument{
|
||||
Type: schemas.OCRDocumentTypeDocumentURL,
|
||||
DocumentURL: schemas.Ptr("https://arxiv.org/pdf/2201.04234"),
|
||||
},
|
||||
Params: &schemas.OCRParameters{
|
||||
Pages: []int{0},
|
||||
},
|
||||
}
|
||||
|
||||
resp, bifrostErr := provider.OCR(ctx, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request)
|
||||
|
||||
require.Nil(t, bifrostErr, "OCR request failed: %v", bifrostErr)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "mistral-ocr-latest", resp.Model)
|
||||
require.NotEmpty(t, resp.Pages, "Expected at least one page")
|
||||
assert.Equal(t, 0, resp.Pages[0].Index)
|
||||
assert.NotEmpty(t, resp.Pages[0].Markdown, "Expected non-empty markdown for page 0")
|
||||
assert.Greater(t, resp.ExtraFields.Latency, int64(0))
|
||||
}
|
||||
Reference in New Issue
Block a user