first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,417 @@
package bedrock
import (
"encoding/json"
"fmt"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// BedrockBatchJobRequest represents a request to create a batch inference job.
type BedrockBatchJobRequest struct {
JobName string `json:"jobName"`
ModelID *string `json:"modelId"`
RoleArn string `json:"roleArn"`
InputDataConfig BedrockInputDataConfig `json:"inputDataConfig"`
OutputDataConfig BedrockOutputDataConfig `json:"outputDataConfig"`
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
Tags []BedrockTag `json:"tags,omitempty"`
}
// BedrockInputDataConfig represents the input configuration for a batch job.
type BedrockInputDataConfig struct {
S3InputDataConfig BedrockS3InputDataConfig `json:"s3InputDataConfig"`
}
// BedrockS3InputDataConfig represents S3 input configuration.
type BedrockS3InputDataConfig struct {
S3Uri string `json:"s3Uri"`
S3InputFormat string `json:"s3InputFormat,omitempty"` // "JSONL"
Endpoint *string `json:"endpoint,omitempty"`
FileID *string `json:"file_id,omitempty"`
}
// BedrockOutputDataConfig represents the output configuration for a batch job.
type BedrockOutputDataConfig struct {
S3OutputDataConfig BedrockS3OutputDataConfig `json:"s3OutputDataConfig"`
}
// BedrockS3OutputDataConfig represents S3 output configuration.
type BedrockS3OutputDataConfig struct {
S3Uri string `json:"s3Uri"`
}
// BedrockTag represents a tag for a batch job.
type BedrockTag struct {
Key string `json:"key"`
Value string `json:"value"`
}
// BedrockBatchJobResponse represents a batch job response.
type BedrockBatchJobResponse struct {
JobArn string `json:"jobArn"`
Status string `json:"status"`
JobName string `json:"jobName,omitempty"`
ModelID string `json:"modelId,omitempty"`
RoleArn string `json:"roleArn,omitempty"`
InputDataConfig *BedrockInputDataConfig `json:"inputDataConfig,omitempty"`
OutputDataConfig *BedrockOutputDataConfig `json:"outputDataConfig,omitempty"`
VpcConfig *BedrockVpcConfig `json:"vpcConfig,omitempty"`
SubmitTime *time.Time `json:"submitTime,omitempty"`
LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
EndTime *time.Time `json:"endTime,omitempty"`
Message string `json:"message,omitempty"`
ClientRequestToken string `json:"clientRequestToken,omitempty"`
JobExpirationTime *time.Time `json:"jobExpirationTime,omitempty"`
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
}
// BedrockBatchJobListResponse represents a list of batch jobs.
type BedrockBatchJobListResponse struct {
InvocationJobSummaries []BedrockBatchJobSummary `json:"invocationJobSummaries"`
NextToken *string `json:"nextToken,omitempty"`
}
// BedrockBatchJobSummary represents a summary of a batch job.
type BedrockBatchJobSummary struct {
JobArn string `json:"jobArn"`
JobName string `json:"jobName"`
ModelID string `json:"modelId"`
Status string `json:"status"`
SubmitTime *time.Time `json:"submitTime,omitempty"`
LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
EndTime *time.Time `json:"endTime,omitempty"`
Message string `json:"message,omitempty"`
}
// BedrockBatchResultRecord represents a single result record in Bedrock batch output JSONL.
type BedrockBatchResultRecord struct {
RecordID string `json:"recordId"`
ModelOutput json.RawMessage `json:"modelOutput,omitempty"`
Error *BedrockBatchError `json:"error,omitempty"`
}
// BedrockBatchError represents an error in batch processing.
type BedrockBatchError struct {
ErrorCode int `json:"errorCode,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
}
// BedrockBatchListRequest represents a request to list batch jobs.
type BedrockBatchListRequest struct {
MaxResults int `json:"maxResults,omitempty"`
NextToken *string `json:"nextToken,omitempty"`
StatusEquals string `json:"statusEquals,omitempty"`
NameContains string `json:"nameContains,omitempty"`
}
// BedrockBatchRetrieveRequest represents a request to retrieve a batch job.
type BedrockBatchRetrieveRequest struct {
JobIdentifier string `json:"jobIdentifier"`
}
// BedrockBatchCancelRequest represents a request to cancel/stop a batch job.
type BedrockBatchCancelRequest struct {
JobIdentifier string `json:"jobIdentifier"`
}
// BedrockBatchCancelResponse represents the response from stopping a batch job.
type BedrockBatchCancelResponse struct {
JobArn string `json:"jobArn"`
Status string `json:"status"`
}
// ToBifrostBatchStatus converts Bedrock status to Bifrost status.
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
switch status {
case "Submitted", "Validating":
return schemas.BatchStatusValidating
case "InProgress":
return schemas.BatchStatusInProgress
case "Completed":
return schemas.BatchStatusCompleted
case "Failed", "PartiallyCompleted":
return schemas.BatchStatusFailed
case "Stopping":
return schemas.BatchStatusCancelling
case "Stopped":
return schemas.BatchStatusCancelled
case "Expired":
return schemas.BatchStatusExpired
case "Scheduled":
return schemas.BatchStatusValidating
default:
return schemas.BatchStatus(status)
}
}
// parseBatchResultsJSONL parses JSONL content from Bedrock batch output into Bifrost format.
// Returns the parsed results and any parse errors encountered.
func parseBatchResultsJSONL(content []byte, provider *BedrockProvider) ([]schemas.BatchResultItem, []schemas.BatchError) {
var results []schemas.BatchResultItem
parseResult := providerUtils.ParseJSONL(content, func(line []byte) error {
var bedrockResult BedrockBatchResultRecord
if err := sonic.Unmarshal(line, &bedrockResult); err != nil {
provider.logger.Warn(fmt.Sprintf("failed to parse batch result line: %v", err))
return err
}
// Convert Bedrock format to Bifrost format
resultItem := schemas.BatchResultItem{
CustomID: bedrockResult.RecordID,
}
if bedrockResult.ModelOutput != nil {
var bodyMap map[string]interface{}
if err := sonic.Unmarshal(bedrockResult.ModelOutput, &bodyMap); err == nil {
resultItem.Response = &schemas.BatchResultResponse{
StatusCode: 200,
Body: bodyMap,
}
} else {
resultItem.Error = &schemas.BatchResultError{
Code: "parse_error",
Message: fmt.Sprintf("failed to parse model output: %v", err),
}
}
}
if bedrockResult.Error != nil {
resultItem.Error = &schemas.BatchResultError{
Code: fmt.Sprintf("%d", bedrockResult.Error.ErrorCode),
Message: bedrockResult.Error.ErrorMessage,
}
// Set status code to indicate error if there's an error
if resultItem.Response == nil {
resultItem.Response = &schemas.BatchResultResponse{
StatusCode: bedrockResult.Error.ErrorCode,
}
}
}
results = append(results, resultItem)
return nil
})
return results, parseResult.Errors
}
// ToBedrockBatchJobResponse converts a Bifrost batch create response to Bedrock format.
func ToBedrockBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *BedrockBatchJobResponse {
// Here if the provider is not Bedrock - then we create a dummy arn and string using the batch ID
if resp.ExtraFields.Provider != schemas.Bedrock {
return &BedrockBatchJobResponse{
JobArn: fmt.Sprintf("arn:aws:bedrock:us-east-1:444444444444:batch:%s", resp.ID),
Status: toBedrockBatchStatus(resp.Status),
}
}
// For bedrock, we go as is
result := &BedrockBatchJobResponse{
JobArn: resp.ID,
Status: toBedrockBatchStatus(resp.Status),
}
if resp.Metadata != nil {
if jobName, ok := resp.Metadata["job_name"]; ok {
result.JobName = jobName
}
}
if resp.CreatedAt > 0 {
t := time.Unix(resp.CreatedAt, 0)
result.SubmitTime = &t
}
return result
}
// ToBedrockBatchJobListResponse converts a Bifrost batch list response to Bedrock format.
func ToBedrockBatchJobListResponse(resp *schemas.BifrostBatchListResponse) *BedrockBatchJobListResponse {
result := &BedrockBatchJobListResponse{
InvocationJobSummaries: make([]BedrockBatchJobSummary, len(resp.Data)),
}
for i, batch := range resp.Data {
summary := BedrockBatchJobSummary{
JobArn: batch.ID,
Status: toBedrockBatchStatus(batch.Status),
}
if batch.Metadata != nil {
if jobName, ok := batch.Metadata["job_name"]; ok {
summary.JobName = jobName
}
if modelId, ok := batch.Metadata["model_id"]; ok {
summary.ModelID = modelId
}
}
if batch.CreatedAt > 0 {
t := time.Unix(batch.CreatedAt, 0)
summary.SubmitTime = &t
}
if batch.CompletedAt != nil && *batch.CompletedAt > 0 {
t := time.Unix(*batch.CompletedAt, 0)
summary.EndTime = &t
}
result.InvocationJobSummaries[i] = summary
}
if resp.LastID != nil {
result.NextToken = resp.LastID
}
return result
}
// ToBedrockBatchJobRetrieveResponse converts a Bifrost batch retrieve response to Bedrock format.
func ToBedrockBatchJobRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *BedrockBatchJobResponse {
result := &BedrockBatchJobResponse{
JobArn: resp.ID,
Status: toBedrockBatchStatus(resp.Status),
}
if resp.Metadata != nil {
if jobName, ok := resp.Metadata["job_name"]; ok {
result.JobName = jobName
}
}
if resp.CreatedAt > 0 {
t := time.Unix(resp.CreatedAt, 0)
result.SubmitTime = &t
}
if resp.CompletedAt != nil && *resp.CompletedAt > 0 {
t := time.Unix(*resp.CompletedAt, 0)
result.EndTime = &t
}
if resp.InputFileID != "" {
result.InputDataConfig = &BedrockInputDataConfig{
S3InputDataConfig: BedrockS3InputDataConfig{
S3Uri: resp.InputFileID,
S3InputFormat: "JSONL",
},
}
}
if resp.OutputFileID != nil && *resp.OutputFileID != "" {
result.OutputDataConfig = &BedrockOutputDataConfig{
S3OutputDataConfig: BedrockS3OutputDataConfig{
S3Uri: *resp.OutputFileID,
},
}
}
return result
}
// toBedrockBatchStatus converts Bifrost batch status to Bedrock status.
func toBedrockBatchStatus(status schemas.BatchStatus) string {
switch status {
case schemas.BatchStatusValidating:
return "Validating"
case schemas.BatchStatusInProgress:
return "InProgress"
case schemas.BatchStatusCompleted:
fallthrough
case schemas.BatchStatusEnded:
return "Completed"
case schemas.BatchStatusFailed:
return "Failed"
case schemas.BatchStatusCancelling:
return "Stopping"
case schemas.BatchStatusCancelled:
return "Stopped"
case schemas.BatchStatusExpired:
return "Expired"
default:
return string(status)
}
}
// ToBifrostBatchListRequest converts a Bedrock batch list request to Bifrost format.
func ToBifrostBatchListRequest(req *BedrockBatchListRequest, provider schemas.ModelProvider) *schemas.BifrostBatchListRequest {
result := &schemas.BifrostBatchListRequest{
Provider: provider,
Limit: req.MaxResults,
}
if req.NextToken != nil {
result.PageToken = req.NextToken
}
if req.StatusEquals != "" || req.NameContains != "" {
result.ExtraParams = make(map[string]interface{})
if req.StatusEquals != "" {
result.ExtraParams["statusEquals"] = req.StatusEquals
}
if req.NameContains != "" {
result.ExtraParams["nameContains"] = req.NameContains
}
}
return result
}
// ToBifrostBatchRetrieveRequest converts a Bedrock batch retrieve request to Bifrost format.
func ToBifrostBatchRetrieveRequest(req *BedrockBatchRetrieveRequest, provider schemas.ModelProvider) *schemas.BifrostBatchRetrieveRequest {
return &schemas.BifrostBatchRetrieveRequest{
Provider: provider,
BatchID: req.JobIdentifier,
}
}
// ToBifrostBatchCancelRequest converts a Bedrock batch cancel request to Bifrost format.
func ToBifrostBatchCancelRequest(req *BedrockBatchCancelRequest, provider schemas.ModelProvider) *schemas.BifrostBatchCancelRequest {
return &schemas.BifrostBatchCancelRequest{
Provider: provider,
BatchID: req.JobIdentifier,
}
}
// ToBedrockBatchCancelResponse converts a Bifrost batch cancel response to Bedrock format.
func ToBedrockBatchCancelResponse(resp *schemas.BifrostBatchCancelResponse) *BedrockBatchCancelResponse {
return &BedrockBatchCancelResponse{
JobArn: resp.ID,
Status: toBedrockBatchStatus(resp.Status),
}
}
// splitJSONL splits JSONL content into individual lines.
func splitJSONL(data []byte) [][]byte {
var lines [][]byte
start := 0
for i, b := range data {
if b == '\n' {
if i > start {
lines = append(lines, data[start:i])
}
start = i + 1
}
}
if start < len(data) {
lines = append(lines, data[start:])
}
return lines
}
// BedrockVpcConfig represents VPC configuration for a batch job.
type BedrockVpcConfig struct {
SecurityGroupIds []string `json:"securityGroupIds,omitempty"`
SubnetIds []string `json:"subnetIds,omitempty"`
}
// BedrockBatchManifest represents the manifest.json.out file structure from S3.
type BedrockBatchManifest struct {
TotalRecordCount int `json:"totalRecordCount"`
ProcessedRecordCount int `json:"processedRecordCount"`
ErrorRecordCount int `json:"errorRecordCount"`
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,449 @@
package bedrock
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockChatCompletionRequest converts a Bifrost request to Bedrock Converse API format
func ToBedrockChatCompletionRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) (*BedrockConverseRequest, error) {
if bifrostReq == nil {
return nil, fmt.Errorf("bifrost request is nil")
}
if bifrostReq.Input == nil {
return nil, fmt.Errorf("only chat completion requests are supported for Bedrock Converse API")
}
bedrockReq := &BedrockConverseRequest{
ModelID: bifrostReq.Model,
}
// Convert messages and system messages
messages, systemMessages, err := convertMessages(bifrostReq.Input)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
bedrockReq.Messages = messages
if len(systemMessages) > 0 {
bedrockReq.System = systemMessages
}
// Convert parameters and configurations
if err := convertChatParameters(ctx, bifrostReq, bedrockReq); err != nil {
return nil, fmt.Errorf("failed to convert chat parameters: %w", err)
}
// Ensure tool config is present when needed
ensureChatToolConfigForConversation(bifrostReq, bedrockReq)
return bedrockReq, nil
}
// ToBifrostChatResponse converts a Bedrock Converse API response to Bifrost format
func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Context, model string) (*schemas.BifrostChatResponse, error) {
if response == nil {
return nil, fmt.Errorf("bedrock response is nil")
}
// Convert content blocks and tool calls
var contentStr *string
var contentBlocks []schemas.ChatContentBlock
var toolCalls []schemas.ChatAssistantMessageToolCall
var reasoningDetails []schemas.ChatReasoningDetails
var reasoningText string
if response.Output.Message != nil {
for _, contentBlock := range response.Output.Message.Content {
// Handle text content
if contentBlock.Text != nil && *contentBlock.Text != "" {
chatContentBlock := schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeText,
Text: contentBlock.Text,
}
contentBlocks = append(contentBlocks, chatContentBlock)
}
if contentBlock.ToolUse != nil {
// Check if this is the structured output tool
if structuredOutputToolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok && contentBlock.ToolUse.Name == structuredOutputToolName {
// This is structured output - set contentStr and skip adding to toolCalls
if contentBlock.ToolUse.Input != nil {
jsonStr := string(contentBlock.ToolUse.Input)
contentStr = &jsonStr
}
continue // Skip adding to toolCalls
}
// Regular tool call processing
var arguments string
if contentBlock.ToolUse.Input != nil {
arguments = string(contentBlock.ToolUse.Input)
} else {
arguments = "{}"
}
toolUseID := contentBlock.ToolUse.ToolUseID
toolUseName := contentBlock.ToolUse.Name
toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{
Index: uint16(len(toolCalls)),
Type: schemas.Ptr("function"),
ID: &toolUseID,
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: &toolUseName,
Arguments: arguments,
},
})
}
// Handle reasoning content
if contentBlock.ReasoningContent != nil {
if contentBlock.ReasoningContent.ReasoningText == nil {
continue
}
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeText,
Text: contentBlock.ReasoningContent.ReasoningText.Text,
Signature: contentBlock.ReasoningContent.ReasoningText.Signature,
})
if contentBlock.ReasoningContent.ReasoningText.Text != nil {
reasoningText += *contentBlock.ReasoningContent.ReasoningText.Text + "\n"
}
}
// Handle document content
if contentBlock.Document != nil {
fileBlock := schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeFile,
File: &schemas.ChatInputFile{},
}
// Set filename from document name
if contentBlock.Document.Name != "" {
fileBlock.File.Filename = &contentBlock.Document.Name
}
// Set file type based on format
if contentBlock.Document.Format != "" {
var fileType string
switch contentBlock.Document.Format {
case "pdf":
fileType = "application/pdf"
case "txt":
fileType = "text/plain"
case "md":
fileType = "text/markdown"
case "html":
fileType = "text/html"
case "csv":
fileType = "text/csv"
case "doc":
fileType = "application/msword"
case "docx":
fileType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
case "xls":
fileType = "application/vnd.ms-excel"
case "xlsx":
fileType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
default:
fileType = "application/pdf"
}
fileBlock.File.FileType = &fileType
}
// Convert document source data
if contentBlock.Document.Source != nil {
if contentBlock.Document.Source.Bytes != nil {
fileBlock.File.FileData = contentBlock.Document.Source.Bytes
} else if contentBlock.Document.Source.Text != nil {
fileBlock.File.FileData = contentBlock.Document.Source.Text
}
}
contentBlocks = append(contentBlocks, fileBlock)
}
}
}
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
contentStr = contentBlocks[0].Text
contentBlocks = nil
}
// Create the message content
messageContent := schemas.ChatMessageContent{
ContentStr: contentStr,
ContentBlocks: contentBlocks,
}
// Create assistant message if we have tool calls
var assistantMessage *schemas.ChatAssistantMessage
if len(toolCalls) > 0 {
assistantMessage = &schemas.ChatAssistantMessage{
ToolCalls: toolCalls,
}
}
if len(reasoningDetails) > 0 {
if assistantMessage == nil {
assistantMessage = &schemas.ChatAssistantMessage{}
}
assistantMessage.ReasoningDetails = reasoningDetails
if reasoningText != "" {
assistantMessage.Reasoning = new(reasoningText)
}
}
// Create the response choice
choices := []schemas.BifrostResponseChoice{
{
Index: 0,
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &messageContent,
ChatAssistantMessage: assistantMessage,
},
},
FinishReason: schemas.Ptr(convertBedrockStopReason(response.StopReason)),
},
}
var usage *schemas.BifrostLLMUsage
if response.Usage != nil {
// Convert usage information
usage = &schemas.BifrostLLMUsage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.TotalTokens,
}
// Handle cached tokens if present
if response.Usage.CacheReadInputTokens > 0 {
if usage.PromptTokensDetails == nil {
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
}
usage.PromptTokensDetails.CachedReadTokens = response.Usage.CacheReadInputTokens
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheReadInputTokens
}
if response.Usage.CacheWriteInputTokens > 0 {
if usage.PromptTokensDetails == nil {
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
}
usage.PromptTokensDetails.CachedWriteTokens = response.Usage.CacheWriteInputTokens
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheWriteInputTokens
}
}
// Create the final Bifrost response
bifrostResponse := &schemas.BifrostChatResponse{
ID: uuid.New().String(),
Model: model,
Object: "chat.completion",
Choices: choices,
Usage: usage,
Created: int(time.Now().Unix()),
ExtraFields: schemas.BifrostResponseExtraFields{
},
}
if response.ServiceTier != nil && response.ServiceTier.Type != "" {
bifrostResponse.ServiceTier = &response.ServiceTier.Type
}
return bifrostResponse, nil
}
// BedrockStreamState tracks per-stream tool call index state.
type BedrockStreamState struct {
nextToolCallIndex int
contentBlockToToolCallIdx map[int]int
}
// NewBedrockStreamState returns initialised stream state for one streaming response.
func NewBedrockStreamState() *BedrockStreamState {
return &BedrockStreamState{
contentBlockToToolCallIdx: make(map[int]int),
}
}
func (chunk *BedrockStreamEvent) ToBifrostChatCompletionStream(state *BedrockStreamState) (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
if state == nil {
state = NewBedrockStreamState()
} else if state.contentBlockToToolCallIdx == nil {
state.contentBlockToToolCallIdx = make(map[int]int)
}
// event with metrics/usage is the last and with stop reason is the second last
switch {
case chunk.Role != nil:
// Send empty response to signal start
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Role: chunk.Role,
},
},
},
},
}
return streamResponse, nil, false
case chunk.Start != nil && chunk.Start.ToolUse != nil:
toolUseStart := chunk.Start.ToolUse
toolCallIdx := 0
if chunk.ContentBlockIndex != nil {
toolCallIdx = state.nextToolCallIndex
state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex] = toolCallIdx
state.nextToolCallIndex++
}
// Create tool call structure for start event
var toolCall schemas.ChatAssistantMessageToolCall
toolCall.Index = uint16(toolCallIdx)
toolCall.ID = schemas.Ptr(toolUseStart.ToolUseID)
toolCall.Type = schemas.Ptr("function")
toolCall.Function.Name = schemas.Ptr(toolUseStart.Name)
toolCall.Function.Arguments = "" // Start with empty arguments
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
},
},
},
},
}
return streamResponse, nil, false
case chunk.Delta != nil:
switch {
case chunk.Delta.Text != nil:
// Handle text delta
text := *chunk.Delta.Text
if text != "" {
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: &text,
},
},
},
},
}
return streamResponse, nil, false
}
case chunk.Delta.ToolUse != nil:
// Handle tool use delta
toolUseDelta := chunk.Delta.ToolUse
toolCallIdx := 0
if chunk.ContentBlockIndex != nil {
toolCallIdx = state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex]
}
// Create tool call structure
var toolCall schemas.ChatAssistantMessageToolCall
toolCall.Index = uint16(toolCallIdx)
toolCall.Type = schemas.Ptr("function")
// For streaming, we need to accumulate tool use data
// This is a simplified approach - in practice, you'd need to track tool calls across chunks
toolCall.Function.Arguments = toolUseDelta.Input
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
},
},
},
},
}
return streamResponse, nil, false
case chunk.Delta.ReasoningContent != nil:
// Handle reasoning content delta
reasoningContentDelta := chunk.Delta.ReasoningContent
// Only construct and return a response when either Text or Signature is set
if (reasoningContentDelta.Text == nil || *reasoningContentDelta.Text == "") && reasoningContentDelta.Signature == nil {
return nil, nil, false
}
var streamResponse *schemas.BifrostChatResponse
if reasoningContentDelta.Text != nil && *reasoningContentDelta.Text != "" {
streamResponse = &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Reasoning: reasoningContentDelta.Text,
ReasoningDetails: []schemas.ChatReasoningDetails{
{
Index: 0,
Type: schemas.BifrostReasoningDetailsTypeText,
Text: reasoningContentDelta.Text,
},
},
},
},
},
},
}
} else if reasoningContentDelta.Signature != nil {
streamResponse = &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
ReasoningDetails: []schemas.ChatReasoningDetails{
{
Index: 0,
Type: schemas.BifrostReasoningDetailsTypeText,
Signature: reasoningContentDelta.Signature,
},
},
},
},
},
},
}
}
return streamResponse, nil, false
}
}
return nil, nil, false
}

View File

@@ -0,0 +1,477 @@
package bedrock
import (
"context"
"encoding/json"
"strings"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// TestConvertToolConfig_DropsServerToolsOnBedrock locks in the bug fix from
// the user-reported repro: sending `web_search_20260209` via the OpenAI-
// compatible /v1/chat/completions endpoint to Bedrock was producing a
// malformed ToolConfig that Bedrock rejected with 400 "The provided request
// is not valid". The fix strips unsupported server tools before the
// conversion loop so the outbound request is valid.
func TestConvertToolConfig_DropsServerToolsOnBedrock(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: "get_weather",
Description: schemas.Ptr("Get weather by city"),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
},
},
},
{
// Server tool — Bedrock doesn't support web_search per Table 20.
// Should be stripped silently.
Type: schemas.ChatToolType("web_search_20260209"),
Name: "web_search",
},
},
}
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
if cfg == nil {
t.Fatalf("expected ToolConfig, got nil (function tool should have survived)")
}
if len(cfg.Tools) != 1 {
t.Fatalf("expected exactly 1 tool (function), got %d: %+v", len(cfg.Tools), cfg.Tools)
}
if cfg.Tools[0].ToolSpec == nil || cfg.Tools[0].ToolSpec.Name != "get_weather" {
t.Errorf("expected function tool 'get_weather' to survive, got %+v", cfg.Tools[0])
}
}
// TestConvertToolConfig_ReturnsNilWhenAllDropped locks in the empty-slice
// guard. Bedrock's Converse API rejects `"toolConfig": {"tools": []}` with a
// 400; when every tool is unsupported and gets stripped, convertToolConfig
// must return nil so no ToolConfig ships at all.
func TestConvertToolConfig_ReturnsNilWhenAllDropped(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("web_search_20260209"),
Name: "web_search",
},
{
Type: schemas.ChatToolType("web_fetch_20260309"),
Name: "web_fetch",
},
{
Type: schemas.ChatToolType("code_execution_20250825"),
Name: "code_execution",
},
},
}
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
if cfg != nil {
t.Fatalf("expected nil ToolConfig (all tools unsupported on Bedrock), got %+v", cfg)
}
}
// TestConvertToolConfig_KeepsBedrockSupportedServerTools — locks in that
// Bedrock-supported server tools (bash, memory, text_editor, computer,
// tool_search) do NOT appear in Converse's typed toolConfig.tools slot —
// they must be tunneled via additionalModelRequestFields (exercised in
// TestCollectBedrockServerTools_*). If the only tool is a server tool,
// toolConfig is nil so we don't ship {"toolConfig": {"tools": []}}.
func TestConvertToolConfig_KeepsBedrockSupportedServerTools(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("bash_20250124"),
Name: "bash",
},
},
}
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
if cfg != nil {
t.Fatalf("expected nil toolConfig (server tools flow via additionalModelRequestFields, not toolSpec), got %+v", cfg)
}
}
// TestCollectBedrockServerTools_BashOnly — bash is Bedrock-supported per the
// B-header list; the helper must emit it as a native-JSON tool entry with no
// derived beta header (bash has no high-confidence 1:1 beta-header mapping;
// callers rely on extra-headers for that).
func TestCollectBedrockServerTools_BashOnly(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("bash_20250124"),
Name: "bash",
},
},
}
tools, betas := collectBedrockServerTools(params)
if len(tools) != 1 {
t.Fatalf("expected 1 server tool, got %d", len(tools))
}
got := string(tools[0])
if !strings.Contains(got, `"type":"bash_20250124"`) || !strings.Contains(got, `"name":"bash"`) {
t.Errorf("expected native Anthropic bash shape, got %s", got)
}
if len(betas) != 0 {
t.Errorf("expected no derived beta headers for bash (no 1:1 mapping), got %v", betas)
}
}
// TestCollectBedrockServerTools_ComputerDerivesBeta — computer_YYYYMMDD must
// derive computer-use-YYYY-MM-DD as the beta header, gated through
// FilterBetaHeadersForProvider(Bedrock) which keeps computer-use-* headers.
func TestCollectBedrockServerTools_ComputerDerivesBeta(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("computer_20251124"),
Name: "computer",
DisplayWidthPx: schemas.Ptr(1280),
DisplayHeightPx: schemas.Ptr(800),
},
},
}
tools, betas := collectBedrockServerTools(params)
if len(tools) != 1 {
t.Fatalf("expected 1 server tool, got %d", len(tools))
}
if !strings.Contains(string(tools[0]), `"display_width_px":1280`) {
t.Errorf("expected computer variant fields to flow through, got %s", string(tools[0]))
}
if len(betas) != 1 || betas[0] != "computer-use-2025-11-24" {
t.Errorf("expected [computer-use-2025-11-24], got %v", betas)
}
}
// TestCollectBedrockServerTools_MemoryDerivesContextManagement — memory
// activates via the context-management-2025-06-27 bundle on Bedrock (cite:
// anthropic/types.go:179).
func TestCollectBedrockServerTools_MemoryDerivesContextManagement(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("memory_20250818"),
Name: "memory",
},
},
}
_, betas := collectBedrockServerTools(params)
if len(betas) != 1 || betas[0] != "context-management-2025-06-27" {
t.Errorf("expected [context-management-2025-06-27], got %v", betas)
}
}
// TestCollectBedrockServerTools_StripsUnsupported — web_search isn't in
// Bedrock's ProviderFeatures (WebSearch=false), so ValidateChatToolsForProvider
// drops it and the helper must emit nothing.
func TestCollectBedrockServerTools_StripsUnsupported(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("web_search_20260209"),
Name: "web_search",
},
},
}
tools, betas := collectBedrockServerTools(params)
if len(tools) != 0 {
t.Errorf("expected no server tools (web_search unsupported on Bedrock), got %d", len(tools))
}
if len(betas) != 0 {
t.Errorf("expected no betas when all tools filtered, got %v", betas)
}
}
// TestCollectBedrockServerTools_FunctionToolsIgnored — function/custom tools
// go through convertToolConfig, not this helper.
func TestCollectBedrockServerTools_FunctionToolsIgnored(t *testing.T) {
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: "get_weather",
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
},
},
},
},
}
tools, betas := collectBedrockServerTools(params)
if len(tools) != 0 || len(betas) != 0 {
t.Errorf("function tools should not flow through server-tool helper, got tools=%d betas=%v", len(tools), betas)
}
}
// TestBuildBedrockServerToolChoice_PinnedServerTool — caller pins a kept
// server tool (computer) by name. Converse's typed toolConfig.toolChoice path
// can't carry this because toolConfig.tools doesn't include server tools; the
// existing reconciliation silently drops the pin. The tunneled path must
// emit {"type":"tool","name":"computer"} into additionalModelRequestFields.
func TestBuildBedrockServerToolChoice_PinnedServerTool(t *testing.T) {
computer := schemas.ChatTool{
Type: schemas.ChatToolType("computer_20251124"),
Name: "computer",
DisplayWidthPx: schemas.Ptr(1280),
}
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{computer},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
Type: schemas.ChatToolChoiceTypeFunction,
Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
},
},
}
choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{computer})
if !ok {
t.Fatalf("expected tunneled tool_choice for pinned server tool, got (nil, false)")
}
got := string(choice)
if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) {
t.Errorf("expected Anthropic-native {type:tool,name:computer}, got %s", got)
}
}
// TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled — function
// tool pins stay on Converse's typed path (toolConfig.toolChoice.tool). The
// helper must not double-emit.
func TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled(t *testing.T) {
fn := schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: "get_weather",
Parameters: &schemas.ToolFunctionParameters{Type: "object"},
},
}
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{fn},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
Type: schemas.ChatToolChoiceTypeFunction,
Function: &schemas.ChatToolChoiceFunction{Name: "get_weather"},
},
},
}
if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn}); ok {
t.Errorf("expected no tunneling for function-tool pin (typed Converse path handles it)")
}
}
// TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools — tool_choice:any
// with only server tools: convertToolConfig returns nil (bedrockTools empty),
// so the typed any-contract is lost. The tunneled path must emit
// {"type":"any"} to preserve the forcing semantics.
func TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools(t *testing.T) {
bash := schemas.ChatTool{
Type: schemas.ChatToolType("bash_20250124"),
Name: "bash",
}
anyStr := string(schemas.ChatToolChoiceTypeAny)
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{bash},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStr: &anyStr,
},
}
choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{bash})
if !ok {
t.Fatalf("expected tunneled any-contract when only server tools are present, got (nil, false)")
}
got := string(choice)
if !strings.Contains(got, `"type":"any"`) {
t.Errorf("expected {type:any}, got %s", got)
}
}
// TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled — when at
// least one function/custom tool is present, Converse's typed
// toolConfig.toolChoice.any carries the any-contract. Don't double-emit.
func TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled(t *testing.T) {
fn := schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: "get_weather",
Parameters: &schemas.ToolFunctionParameters{Type: "object"},
},
}
bash := schemas.ChatTool{
Type: schemas.ChatToolType("bash_20250124"),
Name: "bash",
}
anyStr := string(schemas.ChatToolChoiceTypeAny)
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{fn, bash},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStr: &anyStr,
},
}
if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn, bash}); ok {
t.Errorf("expected no tunneling when function/custom tool is present (typed Converse path handles any)")
}
}
// TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled — the
// caller pins web_search, which ValidateChatToolsForProvider strips on
// Bedrock. The pin name is absent from the filtered set; the helper must not
// fabricate a tunneled tool_choice for a tool that isn't in the request.
func TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled(t *testing.T) {
// The caller's original request had web_search, but it's been stripped.
// We pass the filtered slice (empty for the server-tool axis) to mimic
// the convertChatParameters call path.
params := &schemas.ChatParameters{
Tools: []schemas.ChatTool{{Type: schemas.ChatToolType("web_search_20260209"), Name: "web_search"}},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
Type: schemas.ChatToolChoiceTypeFunction,
Function: &schemas.ChatToolChoiceFunction{Name: "web_search"},
},
},
}
// Filtered (post-ValidateChatToolsForProvider(Bedrock)) — web_search is dropped.
filtered := []schemas.ChatTool{}
if _, ok := buildBedrockServerToolChoice(params, filtered); ok {
t.Errorf("expected no tunneling when pinned name was stripped by provider validation")
}
}
// TestConvertChatParameters_PinnedServerToolE2E — end-to-end verification
// that convertChatParameters composes convertToolConfig +
// collectBedrockServerTools + buildBedrockServerToolChoice such that a
// request pinning a kept server tool produces:
// - AdditionalModelRequestFields.tools containing the server tool
// - AdditionalModelRequestFields.tool_choice with Anthropic-native shape
// - ToolConfig nil (no function tools → Converse's typed path is inactive)
func TestConvertChatParameters_PinnedServerToolE2E(t *testing.T) {
bifrostReq := &schemas.BifrostChatRequest{
Model: "global.anthropic.claude-sonnet-4-6",
Params: &schemas.ChatParameters{
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("computer_20251124"),
Name: "computer",
DisplayWidthPx: schemas.Ptr(1280),
},
},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
Type: schemas.ChatToolChoiceTypeFunction,
Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
},
},
},
}
bedrockReq := &BedrockConverseRequest{}
if err := convertChatParameters(nil, bifrostReq, bedrockReq); err != nil {
t.Fatalf("convertChatParameters failed: %v", err)
}
if bedrockReq.ToolConfig != nil {
t.Errorf("expected nil ToolConfig (no function/custom tools), got %+v", bedrockReq.ToolConfig)
}
if bedrockReq.AdditionalModelRequestFields == nil {
t.Fatalf("expected AdditionalModelRequestFields to carry server-tool payload, got nil")
}
tools, ok := bedrockReq.AdditionalModelRequestFields.Get("tools")
if !ok {
t.Errorf("expected additionalModelRequestFields.tools to be set for server tool")
} else if toolsSlice, castOK := tools.([]json.RawMessage); !castOK || len(toolsSlice) != 1 {
t.Errorf("expected 1 server tool in additionalModelRequestFields.tools, got %+v", tools)
}
choice, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice")
if !ok {
t.Fatalf("expected additionalModelRequestFields.tool_choice to carry pinned server-tool contract")
}
choiceRaw, castOK := choice.(json.RawMessage)
if !castOK {
t.Fatalf("expected tool_choice value to be json.RawMessage, got %T", choice)
}
got := string(choiceRaw)
if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) {
t.Errorf("expected {type:tool,name:computer}, got %s", got)
}
}
// TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice
// locks in the fix for the "two conflicting tool-choice directives" hazard:
// when response_format forces the synthetic bf_so_* tool via
// ToolConfig.ToolChoice, the tunneled additionalModelRequestFields.tool_choice
// (which would pin a server tool) must be suppressed so Bedrock doesn't
// receive both pins in the same Converse call. Uses a Nova model since
// Anthropic models route response_format through native output_config.format
// (no synthetic tool), so the conflict only surfaces on non-Anthropic
// Bedrock targets.
func TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice(t *testing.T) {
responseFormat := any(map[string]any{
"type": "json_schema",
"json_schema": map[string]any{
"name": "classification",
"schema": map[string]any{
"type": "object",
"properties": map[string]any{
"topic": map[string]any{"type": "string"},
},
"required": []any{"topic"},
},
},
})
bifrostReq := &schemas.BifrostChatRequest{
Model: "amazon.nova-pro-v1:0",
Params: &schemas.ChatParameters{
ResponseFormat: &responseFormat,
Tools: []schemas.ChatTool{
{
Type: schemas.ChatToolType("bash_20250124"),
Name: "bash",
},
},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
Type: schemas.ChatToolChoiceTypeFunction,
Function: &schemas.ChatToolChoiceFunction{Name: "bash"},
},
},
},
}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bedrockReq := &BedrockConverseRequest{}
if err := convertChatParameters(ctx, bifrostReq, bedrockReq); err != nil {
t.Fatalf("convertChatParameters failed: %v", err)
}
// Synthetic bf_so_* tool must be injected and pinned via Converse's typed path.
if bedrockReq.ToolConfig == nil {
t.Fatalf("expected ToolConfig with synthetic bf_so_* tool, got nil")
}
if bedrockReq.ToolConfig.ToolChoice == nil || bedrockReq.ToolConfig.ToolChoice.Tool == nil {
t.Fatalf("expected ToolConfig.ToolChoice.Tool to pin synthetic structured-output tool, got %+v", bedrockReq.ToolConfig.ToolChoice)
}
if !strings.HasPrefix(bedrockReq.ToolConfig.ToolChoice.Tool.Name, "bf_so_") {
t.Errorf("expected ToolConfig.ToolChoice.Tool.Name to start with bf_so_, got %q", bedrockReq.ToolConfig.ToolChoice.Tool.Name)
}
// Server tool must still be tunneled so the model has it available.
if bedrockReq.AdditionalModelRequestFields == nil {
t.Fatalf("expected AdditionalModelRequestFields to carry tunneled server-tool payload, got nil")
}
if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tools"); !ok {
t.Errorf("expected additionalModelRequestFields.tools to still carry bash server tool")
}
// Guarded field: tunneled tool_choice MUST be absent because response_format
// forces the synthetic tool. Two tool-choice directives in the same request
// would let Bedrock pick one and silently violate the structured-output contract.
if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice"); ok {
t.Errorf("expected NO additionalModelRequestFields.tool_choice when response_format pins bf_so_* (conflict hazard)")
}
}

View File

@@ -0,0 +1,57 @@
package bedrock
import (
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
const estimatedBytesPerToken = 4
// ToBifrostCountTokensResponse converts a Bedrock count tokens response to Bifrost format
func (resp *BedrockCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
if resp == nil {
return nil
}
totalTokens := resp.InputTokens
return &schemas.BifrostCountTokensResponse{
Model: model,
InputTokens: resp.InputTokens,
TotalTokens: &totalTokens,
Object: "response.input_tokens",
}
}
// ToBedrockCountTokensResponse converts a Bifrost count tokens response to Bedrock native format
func ToBedrockCountTokensResponse(resp *schemas.BifrostCountTokensResponse) *BedrockCountTokensResponse {
if resp == nil {
return nil
}
return &BedrockCountTokensResponse{
InputTokens: resp.InputTokens,
}
}
// isCountTokensUnsupported checks whether a BifrostError indicates that the
// Bedrock model does not support the count-tokens operation.
func isCountTokensUnsupported(err *schemas.BifrostError) bool {
if err == nil || err.Error == nil {
return false
}
return strings.Contains(strings.ToLower(err.Error.Message), "doesn't support counting tokens")
}
// estimateTokenCount returns a rough token count derived from the byte length
// of the serialized request body. Claude's tokenizer averages ~4 bytes per
// token on mixed content; this intentionally rounds up so that context-window
// management decisions stay on the conservative side.
func estimateTokenCount(requestBody []byte) int {
n := len(requestBody)
if n == 0 {
return 0
}
return (n + estimatedBytesPerToken - 1) / estimatedBytesPerToken
}

View File

@@ -0,0 +1,105 @@
package bedrock
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
)
func TestIsCountTokensUnsupported(t *testing.T) {
tests := []struct {
name string
err *schemas.BifrostError
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "nil error field",
err: &schemas.BifrostError{},
expected: false,
},
{
name: "matching bedrock error message",
err: &schemas.BifrostError{
Error: &schemas.ErrorField{
Message: "The provided model doesn't support counting tokens.",
},
},
expected: true,
},
{
name: "matching message with different casing",
err: &schemas.BifrostError{
Error: &schemas.ErrorField{
Message: "the provided model DOESN'T SUPPORT COUNTING TOKENS.",
},
},
expected: true,
},
{
name: "unrelated error message",
err: &schemas.BifrostError{
Error: &schemas.ErrorField{
Message: "access denied",
},
},
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, isCountTokensUnsupported(tc.err))
})
}
}
func TestEstimateTokenCount(t *testing.T) {
tests := []struct {
name string
input []byte
expected int
}{
{
name: "empty input",
input: []byte{},
expected: 0,
},
{
name: "nil input",
input: nil,
expected: 0,
},
{
name: "exact multiple of 4",
input: make([]byte, 100),
expected: 25,
},
{
name: "rounds up",
input: make([]byte, 101),
expected: 26,
},
{
name: "single byte",
input: []byte("x"),
expected: 1,
},
{
name: "realistic json body",
input: []byte(`{"messages":[{"role":"user","content":"Hello, how are you today?"}],"model":"us.anthropic.claude-sonnet-4-6"}`),
expected: 28,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, estimateTokenCount(tc.input))
})
}
}

View File

@@ -0,0 +1,271 @@
package bedrock
import (
"encoding/json"
"fmt"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to Bedrock Titan format
func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockTitanEmbeddingRequest, error) {
if bifrostReq == nil {
return nil, fmt.Errorf("bifrost embedding request is nil")
}
// Validate that only single text input is provided for Titan models
if bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0 {
return nil, fmt.Errorf("no input text provided for embedding")
}
titanReq := &BedrockTitanEmbeddingRequest{}
// Set input text
if bifrostReq.Input.Text != nil {
titanReq.InputText = *bifrostReq.Input.Text
} else if len(bifrostReq.Input.Texts) > 0 {
var embeddingText string
for _, text := range bifrostReq.Input.Texts {
embeddingText += text + " \n"
}
titanReq.InputText = embeddingText
}
if bifrostReq.Params != nil {
titanReq.Dimensions = bifrostReq.Params.Dimensions
if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok {
if b, ok := normalize.(bool); ok {
titanReq.Normalize = &b
}
}
// Forward remaining extra params (excluding normalize which is now a first-class field)
if len(bifrostReq.Params.ExtraParams) > 0 {
extra := make(map[string]interface{})
for k, v := range bifrostReq.Params.ExtraParams {
if k != "normalize" {
extra[k] = v
}
}
if len(extra) > 0 {
titanReq.ExtraParams = extra
}
}
}
return titanReq, nil
}
// ToBifrostEmbeddingResponse converts a Bedrock Titan embedding response to Bifrost format
func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostEmbeddingResponse{
Object: "list",
Data: []schemas.EmbeddingData{
{
Index: 0,
Object: "embedding",
Embedding: schemas.EmbeddingStruct{
EmbeddingArray: response.Embedding,
},
},
},
Usage: &schemas.BifrostLLMUsage{
PromptTokens: response.InputTextTokenCount,
TotalTokens: response.InputTextTokenCount,
},
}
return bifrostResponse
}
// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format.
// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the request body.
func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) {
if bifrostReq == nil {
return nil, fmt.Errorf("bifrost embedding request is nil")
}
if bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0) {
return nil, fmt.Errorf("no input provided for embedding")
}
req := &BedrockCohereEmbeddingRequest{}
// Map texts
if bifrostReq.Input.Text != nil {
req.Texts = []string{*bifrostReq.Input.Text}
} else if len(bifrostReq.Input.Texts) > 0 {
req.Texts = bifrostReq.Input.Texts
}
if bifrostReq.Params != nil {
extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams))
for k, v := range bifrostReq.Params.ExtraParams {
extra[k] = v
}
if v, ok := extra["input_type"]; ok {
if s, ok := v.(string); ok {
req.InputType = s
delete(extra, "input_type")
}
}
if v, ok := extra["truncate"]; ok {
if s, ok := v.(string); ok {
req.Truncate = &s
delete(extra, "truncate")
}
}
if v, ok := extra["embedding_types"]; ok {
if ss, ok := v.([]string); ok {
req.EmbeddingTypes = ss
delete(extra, "embedding_types")
}
}
if v, ok := extra["images"]; ok {
if ss, ok := v.([]string); ok {
req.Images = ss
delete(extra, "images")
}
}
if v, ok := extra["inputs"]; ok {
if inputs, ok := v.([]BedrockCohereEmbeddingInput); ok {
req.Inputs = inputs
delete(extra, "inputs")
}
}
if v, ok := extra["max_tokens"]; ok {
switch n := v.(type) {
case int:
req.MaxTokens = &n
delete(extra, "max_tokens")
case float64:
i := int(n)
req.MaxTokens = &i
delete(extra, "max_tokens")
}
}
if bifrostReq.Params.Dimensions != nil {
req.OutputDimension = bifrostReq.Params.Dimensions
}
if len(extra) > 0 {
req.ExtraParams = extra
}
}
return req, nil
}
// DetermineEmbeddingModelType determines the embedding model type from the model name
func DetermineEmbeddingModelType(model string) (string, error) {
switch {
case strings.Contains(model, "amazon.titan-embed-text"):
return "titan", nil
case strings.Contains(model, "cohere.embed"):
return "cohere", nil
default:
return "", fmt.Errorf("unsupported embedding model: %s", model)
}
}
// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to Bifrost format.
// Bedrock returns embeddings as a raw [][]float32 when response_type is "embeddings_floats"
// (the default, when no embedding_types are requested), and as a typed object when
// response_type is "embeddings_by_type".
func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) {
if r == nil {
return nil, fmt.Errorf("nil Bedrock Cohere embedding response")
}
bifrostResponse := &schemas.BifrostEmbeddingResponse{Object: "list"}
switch r.ResponseType {
case "embeddings_by_type":
// Object form: {"float": [[...]], "int8": [[...]], "uint8": [[...]], "binary": [[...]], "ubinary": [[...]], "base64": [...]}
var typed struct {
Float [][]float32 `json:"float"`
Base64 []string `json:"base64"`
Int8 [][]int8 `json:"int8"`
Uint8 [][]int32 `json:"uint8"` // int32 avoids []byte→base64 JSON issue
Binary [][]int8 `json:"binary"`
Ubinary [][]int32 `json:"ubinary"` // int32 avoids []byte→base64 JSON issue
}
if err := json.Unmarshal(r.Embeddings, &typed); err != nil {
return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err)
}
if typed.Float != nil {
for i, emb := range typed.Float {
float64Emb := make([]float64, len(emb))
for j, v := range emb {
float64Emb[j] = float64(v)
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb},
})
}
}
if typed.Base64 != nil {
for i, emb := range typed.Base64 {
e := emb
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{EmbeddingStr: &e},
})
}
}
for i, emb := range typed.Int8 {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb},
})
}
for i, emb := range typed.Binary {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb},
})
}
for i, emb := range typed.Uint8 {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb},
})
}
for i, emb := range typed.Ubinary {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb},
})
}
default:
// Default / "embeddings_floats": raw array form [[...], [...]]
var floats [][]float32
if err := json.Unmarshal(r.Embeddings, &floats); err != nil {
return nil, fmt.Errorf("error parsing embeddings_floats: %w", err)
}
for i, emb := range floats {
float64Emb := make([]float64, len(emb))
for j, v := range emb {
float64Emb[j] = float64(v)
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb},
})
}
}
return bifrostResponse, nil
}

View File

@@ -0,0 +1,114 @@
package bedrock
import (
"context"
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToBedrockCohereEmbeddingRequest(t *testing.T) {
t.Run("returns error for nil request", func(t *testing.T) {
req, err := ToBedrockCohereEmbeddingRequest(nil)
require.Error(t, err)
assert.Nil(t, req)
assert.Contains(t, err.Error(), "nil")
})
t.Run("returns error for missing input", func(t *testing.T) {
req, err := ToBedrockCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{})
require.Error(t, err)
assert.Nil(t, req)
assert.Contains(t, err.Error(), "no input")
})
t.Run("returns error for non-nil but empty input", func(t *testing.T) {
req, err := ToBedrockCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
Input: &schemas.EmbeddingInput{},
})
require.Error(t, err)
assert.Nil(t, req)
assert.Contains(t, err.Error(), "no input")
})
t.Run("single text strips model and extracts typed params", func(t *testing.T) {
text := "hello"
truncate := "RIGHT"
dimensions := 512
bifrostReq := &schemas.BifrostEmbeddingRequest{
Model: "cohere.embed-english-v3",
Input: &schemas.EmbeddingInput{Text: &text},
Params: &schemas.EmbeddingParameters{
Dimensions: &dimensions,
ExtraParams: map[string]interface{}{
"input_type": "search_query",
"embedding_types": []string{"float"},
"truncate": truncate,
"max_tokens": float64(128),
"trace_id": "req-123",
},
},
}
req, err := ToBedrockCohereEmbeddingRequest(bifrostReq)
require.NoError(t, err)
require.NotNil(t, req)
assert.Equal(t, "search_query", req.InputType)
assert.Equal(t, []string{"hello"}, req.Texts)
assert.Equal(t, []string{"float"}, req.EmbeddingTypes)
assert.Equal(t, &dimensions, req.OutputDimension)
assert.Equal(t, 128, *req.MaxTokens)
require.NotNil(t, req.Truncate)
assert.Equal(t, truncate, *req.Truncate)
assert.Equal(t, map[string]interface{}{"trace_id": "req-123"}, req.ExtraParams)
})
t.Run("multiple texts preserve bedrock body shape", func(t *testing.T) {
bifrostReq := &schemas.BifrostEmbeddingRequest{
Model: "cohere.embed-multilingual-v3",
Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
Params: &schemas.EmbeddingParameters{
ExtraParams: map[string]interface{}{
"input_type": "search_document",
},
},
}
req, err := ToBedrockCohereEmbeddingRequest(bifrostReq)
require.NoError(t, err)
assert.Equal(t, []string{"hello", "world"}, req.Texts)
assert.Equal(t, "search_document", req.InputType)
})
}
func TestToBedrockCohereEmbeddingRequestBodyOmitsModel(t *testing.T) {
text := "hello"
bifrostReq := &schemas.BifrostEmbeddingRequest{
Model: "cohere.embed-english-v3",
Input: &schemas.EmbeddingInput{Text: &text},
Params: &schemas.EmbeddingParameters{
ExtraParams: map[string]interface{}{
"input_type": "search_document",
"embedding_types": []string{"float"},
},
},
}
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
context.Background(),
bifrostReq,
func() (providerUtils.RequestBodyWithExtraParams, error) {
return ToBedrockCohereEmbeddingRequest(bifrostReq)
},
)
require.Nil(t, bifrostErr)
assert.NotContains(t, string(wireBody), `"model"`)
assert.JSONEq(t, `{
"input_type": "search_document",
"texts": ["hello"],
"embedding_types": ["float"]
}`, string(wireBody))
}

View File

@@ -0,0 +1,34 @@
package bedrock
import (
"net/http"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
func parseBedrockHTTPError(statusCode int, headers http.Header, body []byte) *schemas.BifrostError {
fastResp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(fastResp)
fastResp.SetStatusCode(statusCode)
for k, values := range headers {
for _, value := range values {
fastResp.Header.Add(k, value)
}
}
fastResp.SetBody(body)
var errorResp BedrockError
bifrostErr := providerUtils.HandleProviderAPIError(fastResp, &errorResp)
if errorResp.Message != "" {
if bifrostErr.Error == nil {
bifrostErr.Error = &schemas.ErrorField{}
}
bifrostErr.Error.Message = errorResp.Message
bifrostErr.Error.Code = errorResp.Code
}
return bifrostErr
}

View File

@@ -0,0 +1,276 @@
package bedrock
import (
"fmt"
"html"
"net/url"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
// escapeS3KeyForURL escapes each segment of an S3 key path individually.
// This prevents signature and URL parsing failures with special characters.
// We can't use url.PathEscape on the full key as it escapes "/" to "%2F",
// but we need each segment properly escaped per RFC 3986 for AWS SigV4 signing.
func escapeS3KeyForURL(key string) string {
if key == "" {
return ""
}
parts := strings.Split(key, "/")
for i, p := range parts {
parts[i] = url.PathEscape(p)
}
return strings.Join(parts, "/")
}
// parseS3URI parses an S3 URI (s3://bucket/key or bucket-name) and returns bucket name and key.
func parseS3URI(uri string) (bucket, key string) {
if strings.HasPrefix(uri, "s3://") {
uri = strings.TrimPrefix(uri, "s3://")
parts := strings.SplitN(uri, "/", 2)
bucket = parts[0]
if len(parts) > 1 {
key = parts[1]
}
} else {
// Assume it's just a bucket name
bucket = uri
}
return
}
// S3ListObjectsResponse represents S3 ListObjectsV2 response.
type S3ListObjectsResponse struct {
Contents []S3Object `json:"contents"`
IsTruncated bool `json:"isTruncated"`
NextContinuationToken string `json:"nextContinuationToken,omitempty"`
}
// S3Object represents an S3 object in list response.
type S3Object struct {
Key string `json:"key"`
Size int64 `json:"size"`
LastModified time.Time `json:"lastModified"`
}
// parseS3ListResponse parses S3 ListObjectsV2 XML response.
func parseS3ListResponse(body []byte, resp *S3ListObjectsResponse) error {
// S3 returns XML, so we need to parse it
// Try JSON first (some S3-compatible services return JSON)
if err := sonic.Unmarshal(body, resp); err == nil && len(resp.Contents) > 0 {
return nil
}
// Parse XML using simple string matching for key fields
// This is a lightweight approach that doesn't require encoding/xml
bodyStr := string(body)
// Parse IsTruncated
if strings.Contains(bodyStr, "<IsTruncated>true</IsTruncated>") {
resp.IsTruncated = true
}
// Parse NextContinuationToken
if start := strings.Index(bodyStr, "<NextContinuationToken>"); start >= 0 {
start += len("<NextContinuationToken>")
if end := strings.Index(bodyStr[start:], "</NextContinuationToken>"); end >= 0 {
resp.NextContinuationToken = bodyStr[start : start+end]
}
}
// Parse Contents
contents := bodyStr
for {
start := strings.Index(contents, "<Contents>")
if start < 0 {
break
}
end := strings.Index(contents[start:], "</Contents>")
if end < 0 {
break
}
contentBlock := contents[start : start+end+len("</Contents>")]
contents = contents[start+end+len("</Contents>"):]
obj := S3Object{}
// Parse Key
if keyStart := strings.Index(contentBlock, "<Key>"); keyStart >= 0 {
keyStart += len("<Key>")
if keyEnd := strings.Index(contentBlock[keyStart:], "</Key>"); keyEnd >= 0 {
obj.Key = html.UnescapeString(contentBlock[keyStart : keyStart+keyEnd])
}
}
// Parse Size
if sizeStart := strings.Index(contentBlock, "<Size>"); sizeStart >= 0 {
sizeStart += len("<Size>")
if sizeEnd := strings.Index(contentBlock[sizeStart:], "</Size>"); sizeEnd >= 0 {
sizeStr := contentBlock[sizeStart : sizeStart+sizeEnd]
fmt.Sscanf(sizeStr, "%d", &obj.Size)
}
}
// Parse LastModified
if lmStart := strings.Index(contentBlock, "<LastModified>"); lmStart >= 0 {
lmStart += len("<LastModified>")
if lmEnd := strings.Index(contentBlock[lmStart:], "</LastModified>"); lmEnd >= 0 {
lmStr := contentBlock[lmStart : lmStart+lmEnd]
if t, err := time.Parse(time.RFC3339Nano, lmStr); err == nil {
obj.LastModified = t
}
}
}
if obj.Key != "" {
resp.Contents = append(resp.Contents, obj)
}
}
return nil
}
// ==================== BEDROCK FILE TYPE CONVERTERS ====================
// ToBedrockFileUploadResponse converts a Bifrost file upload response to Bedrock format.
func ToBedrockFileUploadResponse(resp *schemas.BifrostFileUploadResponse) *BedrockFileUploadResponse {
if resp == nil {
return nil
}
// Parse S3 URI to get bucket and key
bucket, key := parseS3URI(resp.ID)
return &BedrockFileUploadResponse{
S3Uri: resp.ID,
Bucket: bucket,
Key: key,
SizeBytes: resp.Bytes,
ContentType: "application/jsonl",
CreatedAt: resp.CreatedAt,
}
}
// ToBedrockFileListResponse converts a Bifrost file list response to Bedrock format.
func ToBedrockFileListResponse(resp *schemas.BifrostFileListResponse) *BedrockFileListResponse {
if resp == nil {
return nil
}
files := make([]BedrockFileInfo, len(resp.Data))
for i, f := range resp.Data {
_, key := parseS3URI(f.ID)
files[i] = BedrockFileInfo{
S3Uri: f.ID,
Key: key,
SizeBytes: f.Bytes,
LastModified: f.CreatedAt,
}
}
return &BedrockFileListResponse{
Files: files,
IsTruncated: resp.HasMore,
}
}
// ToBedrockFileRetrieveResponse converts a Bifrost file retrieve response to Bedrock format.
func ToBedrockFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *BedrockFileRetrieveResponse {
if resp == nil {
return nil
}
_, key := parseS3URI(resp.ID)
return &BedrockFileRetrieveResponse{
S3Uri: resp.ID,
Key: key,
SizeBytes: resp.Bytes,
LastModified: resp.CreatedAt,
ContentType: "application/jsonl",
}
}
// ToBedrockFileDeleteResponse converts a Bifrost file delete response to Bedrock format.
func ToBedrockFileDeleteResponse(resp *schemas.BifrostFileDeleteResponse) *BedrockFileDeleteResponse {
if resp == nil {
return nil
}
return &BedrockFileDeleteResponse{
S3Uri: resp.ID,
Deleted: resp.Deleted,
}
}
// ToBedrockFileContentResponse converts a Bifrost file content response to Bedrock format.
func ToBedrockFileContentResponse(resp *schemas.BifrostFileContentResponse) *BedrockFileContentResponse {
if resp == nil {
return nil
}
return &BedrockFileContentResponse{
S3Uri: resp.FileID,
Content: resp.Content,
ContentType: resp.ContentType,
SizeBytes: int64(len(resp.Content)),
}
}
// ==================== S3 API XML FORMATTERS ====================
// ToS3ListObjectsV2XML converts a Bifrost file list response to S3 ListObjectsV2 XML format.
func ToS3ListObjectsV2XML(resp *schemas.BifrostFileListResponse, bucket, prefix string, maxKeys int) []byte {
if resp == nil {
return []byte(`<?xml version="1.0" encoding="UTF-8"?><ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></ListBucketResult>`)
}
var sb strings.Builder
sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
sb.WriteString(`<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">`)
sb.WriteString(fmt.Sprintf("<Name>%s</Name>", bucket))
sb.WriteString(fmt.Sprintf("<Prefix>%s</Prefix>", prefix))
sb.WriteString(fmt.Sprintf("<KeyCount>%d</KeyCount>", len(resp.Data)))
sb.WriteString(fmt.Sprintf("<MaxKeys>%d</MaxKeys>", maxKeys))
if resp.HasMore {
sb.WriteString("<IsTruncated>true</IsTruncated>")
if resp.After != nil && *resp.After != "" {
sb.WriteString(fmt.Sprintf("<NextContinuationToken>%s</NextContinuationToken>", *resp.After))
}
} else {
sb.WriteString("<IsTruncated>false</IsTruncated>")
}
for _, f := range resp.Data {
// Extract key from S3 URI
_, key := parseS3URI(f.ID)
sb.WriteString("<Contents>")
sb.WriteString(fmt.Sprintf("<Key>%s</Key>", key))
sb.WriteString(fmt.Sprintf("<Size>%d</Size>", f.Bytes))
if f.CreatedAt > 0 {
sb.WriteString(fmt.Sprintf("<LastModified>%s</LastModified>", time.Unix(f.CreatedAt, 0).UTC().Format(time.RFC3339)))
}
sb.WriteString("<StorageClass>STANDARD</StorageClass>")
sb.WriteString("</Contents>")
}
sb.WriteString("</ListBucketResult>")
return []byte(sb.String())
}
// ToS3ErrorXML converts an error to S3 error XML format.
func ToS3ErrorXML(code, message, resource, requestID string) []byte {
var sb strings.Builder
sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
sb.WriteString("<Error>")
sb.WriteString(fmt.Sprintf("<Code>%s</Code>", code))
sb.WriteString(fmt.Sprintf("<Message>%s</Message>", message))
sb.WriteString(fmt.Sprintf("<Resource>%s</Resource>", resource))
sb.WriteString(fmt.Sprintf("<RequestId>%s</RequestId>", requestID))
sb.WriteString("</Error>")
return []byte(sb.String())
}

View File

@@ -0,0 +1,742 @@
package bedrock
import (
"encoding/base64"
"fmt"
"strconv"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// mapQualityToBedrock maps quality values to Bedrock format:
// - "low" and "medium" -> "standard"
// - "high" -> "premium"
// - "standard" and "premium" (case-insensitive) -> pass through as lowercase ("standard"/"premium")
func mapQualityToBedrock(quality *string) *string {
if quality == nil {
return nil
}
qualityLower := strings.ToLower(strings.TrimSpace(*quality))
switch qualityLower {
case "low", "medium":
return schemas.Ptr("standard")
case "high":
return schemas.Ptr("premium")
case "standard":
return schemas.Ptr("standard")
case "premium":
return schemas.Ptr("premium")
default:
return quality
}
}
// isStabilityAIModel returns true if the model is a Stability AI model (contains "stability.")
func isStabilityAIModel(model string) bool {
return strings.Contains(strings.ToLower(model), "stability.")
}
// isPromptOnlyImageGenerationModel returns true for image generation models that use a flat
// {"prompt": "..."} payload (no taskType field). Covers Vertex Imagen and similar models.
// Stability AI is excluded here — it's handled separately because it also supports image edit.
func isPromptOnlyImageGenerationModel(model string) bool {
m := strings.ToLower(model)
return strings.Contains(m, "image")
}
// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI
// flat request format used by Bedrock (stability.stable-image-* models).
func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}
if request.Input == nil {
return nil, fmt.Errorf("request input is required")
}
req := &StabilityAIImageGenerationRequest{
Prompt: request.Input.Prompt,
}
if request.Params != nil {
if request.Params.AspectRatio != nil {
req.AspectRatio = request.Params.AspectRatio
}
if request.Params.OutputFormat != nil {
req.OutputFormat = request.Params.OutputFormat
}
if request.Params.Seed != nil {
req.Seed = request.Params.Seed
}
if request.Params.NegativePrompt != nil {
req.NegativePrompt = request.Params.NegativePrompt
}
if request.Params.ExtraParams != nil {
// aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set
if req.AspectRatio == nil {
if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok {
delete(request.Params.ExtraParams, "aspect_ratio")
req.AspectRatio = ar
}
}
req.ExtraParams = request.Params.ExtraParams
}
}
return req, nil
}
// ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image generation request
func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}
if request.Input == nil {
return nil, fmt.Errorf("request input is required")
}
bedrockReq := &BedrockImageGenerationRequest{
TaskType: schemas.Ptr(TaskTypeTextImage),
TextToImageParams: &BedrockTextToImageParams{
Text: request.Input.Prompt,
},
ImageGenerationConfig: &ImageGenerationConfig{},
}
if request.Params != nil {
if request.Params.N != nil {
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
}
if request.Params.NegativePrompt != nil {
bedrockReq.TextToImageParams.NegativeText = request.Params.NegativePrompt
}
if request.Params.Seed != nil {
bedrockReq.ImageGenerationConfig.Seed = request.Params.Seed
}
if request.Params.Quality != nil {
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(request.Params.Quality)
}
if request.Params.Style != nil {
bedrockReq.TextToImageParams.Style = request.Params.Style
}
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
if len(size) != 2 {
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
}
width, err := strconv.Atoi(size[0])
if err != nil {
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
}
height, err := strconv.Atoi(size[1])
if err != nil {
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
}
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
}
if request.Params.ExtraParams != nil {
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
delete(request.Params.ExtraParams, "cfgScale")
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
}
bedrockReq.ExtraParams = request.Params.ExtraParams
}
}
return bedrockReq, nil
}
// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to
// the native Bedrock invoke API response format used by Stability AI models.
// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas.
func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) {
if response == nil {
return nil, fmt.Errorf("response is nil")
}
result := &BedrockImageGenerationResponse{}
for _, d := range response.Data {
result.Images = append(result.Images, d.B64JSON)
}
if response.ImageGenerationResponseParameters != nil {
result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons
result.Seeds = response.ImageGenerationResponseParameters.Seeds
}
return result, nil
}
// ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request
func ToBedrockImageVariationRequest(request *schemas.BifrostImageVariationRequest) (*BedrockImageVariationRequest, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}
if request.Input == nil || request.Input.Image.Image == nil || len(request.Input.Image.Image) == 0 {
return nil, fmt.Errorf("request.Input.Image is required")
}
bedrockReq := &BedrockImageVariationRequest{
TaskType: schemas.Ptr(TaskTypeImageVariation),
ImageVariationParams: &BedrockImageVariationParams{
Images: []string{},
},
ImageGenerationConfig: &ImageGenerationConfig{},
}
// Convert all images to base64 strings
// Primary image from Input.Image
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Image.Image)
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, imageBase64)
// Additional images from ExtraParams (stored as [][]byte)
if request.Params != nil && request.Params.ExtraParams != nil {
if additionalImages, ok := request.Params.ExtraParams["images"]; ok {
delete(request.Params.ExtraParams, "images")
// Handle array of byte arrays (stored by HTTP handler)
if imagesArray, ok := additionalImages.([][]byte); ok {
for _, imgBytes := range imagesArray {
if len(imgBytes) > 0 {
additionalBase64 := base64.StdEncoding.EncodeToString(imgBytes)
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, additionalBase64)
}
}
}
}
// Extract optional fields from ExtraParams
if prompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["prompt"]); ok {
delete(request.Params.ExtraParams, "prompt")
bedrockReq.ImageVariationParams.Text = prompt
}
if negativeText, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["negativeText"]); ok {
delete(request.Params.ExtraParams, "negativeText")
bedrockReq.ImageVariationParams.NegativeText = negativeText
}
if similarityStrength, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["similarityStrength"]); ok {
delete(request.Params.ExtraParams, "similarityStrength")
// Validate similarityStrength range (0.2 to 1.0)
if *similarityStrength < 0.2 || *similarityStrength > 1.0 {
return nil, fmt.Errorf("similarityStrength must be between 0.2 and 1.0, got %f", *similarityStrength)
}
bedrockReq.ImageVariationParams.SimilarityStrength = similarityStrength
}
bedrockReq.ExtraParams = request.Params.ExtraParams
}
// Map standard params to ImageGenerationConfig
if request.Params != nil {
if request.Params.N != nil {
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
}
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
if len(size) != 2 {
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
}
width, err := strconv.Atoi(size[0])
if err != nil {
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
}
height, err := strconv.Atoi(size[1])
if err != nil {
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
}
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
}
// Extract quality and cfgScale from ExtraParams
if request.Params.ExtraParams != nil {
if quality, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["quality"]); ok {
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(quality)
}
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
}
}
}
return bedrockReq, nil
}
// ToBedrockImageEditRequest converts a Bifrost image edit request to a Bedrock image edit request
func ToBedrockImageEditRequest(request *schemas.BifrostImageEditRequest) (*BedrockImageEditRequest, error) {
// Validate request
if request == nil || request.Input == nil {
return nil, fmt.Errorf("request or input is nil")
}
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
return nil, fmt.Errorf("at least one image is required")
}
// Validate and extract type (required)
if request.Params == nil || request.Params.Type == nil {
return nil, fmt.Errorf("type field is required (must be inpainting, outpainting, or background_removal)")
}
editType := strings.ToLower(*request.Params.Type)
// Convert first image to base64
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
bedrockReq := &BedrockImageEditRequest{}
switch editType {
case "inpainting":
bedrockReq.TaskType = schemas.Ptr(TaskTypeInpainting)
bedrockReq.InPaintingParams = buildInPaintingParams(imageBase64, request)
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
case "outpainting":
bedrockReq.TaskType = schemas.Ptr(TaskTypeOutpainting)
bedrockReq.OutPaintingParams = buildOutPaintingParams(imageBase64, request)
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
case "background_removal":
bedrockReq.TaskType = schemas.Ptr(TaskTypeBackgroundRemoval)
bedrockReq.BackgroundRemovalParams = &BedrockBackgroundRemovalParams{
Image: imageBase64,
}
default:
return nil, fmt.Errorf("unsupported type for Bedrock: %s (must be inpainting, outpainting, or background_removal)", editType)
}
bedrockReq.ExtraParams = request.Params.ExtraParams
return bedrockReq, nil
}
// Helper functions
func buildInPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockInPaintingParams {
params := &BedrockInPaintingParams{
Image: imageBase64,
Text: request.Input.Prompt,
}
if request.Params.NegativePrompt != nil {
params.NegativeText = request.Params.NegativePrompt
}
if request.Params.ExtraParams != nil {
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
delete(request.Params.ExtraParams, "mask_prompt")
params.MaskPrompt = maskPrompt
}
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
delete(request.Params.ExtraParams, "return_mask")
params.ReturnMask = returnMask
}
}
// Convert mask to base64 if present
if len(request.Params.Mask) > 0 {
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
params.MaskImage = &maskBase64
}
return params
}
func buildOutPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockOutPaintingParams {
params := &BedrockOutPaintingParams{
Text: request.Input.Prompt,
Image: imageBase64,
}
if request.Params.NegativePrompt != nil {
params.NegativeText = request.Params.NegativePrompt
}
if request.Params.ExtraParams != nil {
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
delete(request.Params.ExtraParams, "mask_prompt")
params.MaskPrompt = maskPrompt
}
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
delete(request.Params.ExtraParams, "return_mask")
params.ReturnMask = returnMask
}
if outPaintingMode, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["outpainting_mode"]); ok {
// Validate mode
mode := strings.ToUpper(*outPaintingMode)
if mode == "DEFAULT" || mode == "PRECISE" {
delete(request.Params.ExtraParams, "outpainting_mode")
params.OutPaintingMode = &mode
}
}
}
// Convert mask to base64 if present
if len(request.Params.Mask) > 0 {
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
params.MaskImage = &maskBase64
}
return params
}
func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGenerationConfig {
config := &ImageGenerationConfig{}
if params.N != nil {
config.NumberOfImages = params.N
}
// Parse size (reuse logic from image generation)
if params.Size != nil && strings.TrimSpace(strings.ToLower(*params.Size)) != "auto" {
size := strings.Split(strings.TrimSpace(strings.ToLower(*params.Size)), "x")
if len(size) == 2 {
width, err := strconv.Atoi(size[0])
if err == nil {
height, err := strconv.Atoi(size[1])
if err == nil {
config.Width = schemas.Ptr(width)
config.Height = schemas.Ptr(height)
}
}
}
}
if params.Quality != nil {
config.Quality = mapQualityToBedrock(params.Quality)
}
if params.Seed != nil {
config.Seed = params.Seed
}
if params.ExtraParams != nil {
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["cfgScale"]); ok {
delete(params.ExtraParams, "cfgScale")
config.CfgScale = cfgScale
}
}
return config
}
// getStabilityAITaskTypeFromParams maps the generic BifrostImageEditParameters.Type value
// to a Stability AI task type string. Returns "" if the value is not a recognized Stability AI task type.
func getStabilityAITaskTypeFromParams(t string) string {
switch strings.ToLower(t) {
case "inpainting", "inpaint":
return "inpaint"
case "outpainting", "outpaint":
return "outpaint"
case "background_removal", "remove_background":
return "remove-bg"
case "erase_object":
return "erase-object"
case "upscale_fast":
return "upscale-fast"
case "upscale_creative":
return "upscale-creative"
case "upscale_conservative":
return "upscale-conservative"
case "recolor":
return "recolor"
case "search_replace":
return "search-replace"
case "control_sketch":
return "control-sketch"
case "control_structure":
return "control-structure"
case "style_guide":
return "style-guide"
case "style_transfer":
return "style-transfer"
default:
return ""
}
}
// getStabilityAIEditTaskType infers the Stability AI edit task from the model name.
// Returns an error if the model name does not match any known pattern.
func getStabilityAIEditTaskType(model string) (string, error) {
m := strings.ToLower(model)
switch {
case strings.Contains(m, "stable-creative-upscale"):
return "upscale-creative", nil
case strings.Contains(m, "stable-conservative-upscale"):
return "upscale-conservative", nil
case strings.Contains(m, "stable-fast-upscale"):
return "upscale-fast", nil
case strings.Contains(m, "stable-image-inpaint"):
return "inpaint", nil
case strings.Contains(m, "stable-outpaint"):
return "outpaint", nil
case strings.Contains(m, "stable-image-search-recolor"):
return "recolor", nil
case strings.Contains(m, "stable-image-search-replace"):
return "search-replace", nil
case strings.Contains(m, "stable-image-erase-object"):
return "erase-object", nil
case strings.Contains(m, "stable-image-remove-background"):
return "remove-bg", nil
case strings.Contains(m, "stable-image-control-sketch"):
return "control-sketch", nil
case strings.Contains(m, "stable-image-control-structure"):
return "control-structure", nil
case strings.Contains(m, "stable-image-style-guide"):
return "style-guide", nil
case strings.Contains(m, "stable-style-transfer"):
return "style-transfer", nil
default:
return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model)
}
}
// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request
// format used by Bedrock edit models. Only fields valid for the detected task type are populated.
// deployment is the resolved model identifier (after applying any deployment alias mapping); it is
// used for task-type inference so that alias-mapped models route correctly.
func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) {
if request == nil || request.Input == nil {
return nil, fmt.Errorf("request or input is nil")
}
var taskType string
if request.Params != nil && request.Params.Type != nil {
taskType = getStabilityAITaskTypeFromParams(*request.Params.Type)
}
if taskType == "" {
var err error
taskType, err = getStabilityAIEditTaskType(deployment)
if err != nil {
return nil, err
}
}
req := &StabilityAIImageEditRequest{}
// Image sourcing
if taskType == "style-transfer" {
if len(request.Input.Images) != 2 {
return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image")
}
if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 {
return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image")
}
initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image)
req.InitImage = &initB64
req.StyleImage = &styleB64
} else {
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
return nil, fmt.Errorf("at least one image is required")
}
imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
req.Image = &imageB64
}
// Common fields populated based on task allowlist
prompt := request.Input.Prompt
switch taskType {
case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure",
"style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer":
req.Prompt = &prompt
}
// Negative prompt
if request.Params != nil && request.Params.NegativePrompt != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.NegativePrompt = request.Params.NegativePrompt
}
}
// Seed
if request.Params != nil && request.Params.Seed != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.Seed = request.Params.Seed
}
}
// Mask (from Params.Mask bytes)
if request.Params != nil && len(request.Params.Mask) > 0 {
switch taskType {
case "inpaint", "erase-object":
maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
req.Mask = &maskB64
}
}
// ExtraParams
if request.Params != nil {
// Typed OutputFormat takes priority over ExtraParams
if request.Params.OutputFormat != nil {
req.OutputFormat = request.Params.OutputFormat
}
if request.Params.ExtraParams != nil {
ep := make(map[string]interface{}, len(request.Params.ExtraParams))
for k, v := range request.Params.ExtraParams {
ep[k] = v
}
// output_format — all tasks (fallback if not already set by typed field)
if req.OutputFormat == nil {
if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok {
delete(ep, "output_format")
req.OutputFormat = v
}
}
// style_preset
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative":
if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok {
delete(ep, "style_preset")
req.StylePreset = v
}
}
// grow_mask
switch taskType {
case "inpaint", "recolor", "search-replace", "erase-object":
if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok {
delete(ep, "grow_mask")
req.GrowMask = v
}
}
// outpaint directional fields
if taskType == "outpaint" {
if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok {
delete(ep, "left")
req.Left = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok {
delete(ep, "right")
req.Right = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok {
delete(ep, "up")
req.Up = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok {
delete(ep, "down")
req.Down = v
}
}
// creativity
switch taskType {
case "upscale-creative", "upscale-conservative", "outpaint":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok {
delete(ep, "creativity")
req.Creativity = v
}
}
// select_prompt (recolor)
if taskType == "recolor" {
if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok {
delete(ep, "select_prompt")
req.SelectPrompt = v
}
}
// search_prompt (search-replace)
if taskType == "search-replace" {
if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok {
delete(ep, "search_prompt")
req.SearchPrompt = v
}
}
// control_strength
switch taskType {
case "control-sketch", "control-structure":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok {
delete(ep, "control_strength")
req.ControlStrength = v
}
}
// style-guide fields
if taskType == "style-guide" {
if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok {
delete(ep, "aspect_ratio")
req.AspectRatio = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok {
delete(ep, "fidelity")
req.Fidelity = v
}
}
// style-transfer fields
if taskType == "style-transfer" {
if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok {
delete(ep, "style_strength")
req.StyleStrength = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok {
delete(ep, "composition_fidelity")
req.CompositionFidelity = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok {
delete(ep, "change_strength")
req.ChangeStrength = v
}
}
req.ExtraParams = ep
}
}
// Validate required per-task fields
if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") {
return nil, fmt.Errorf("select_prompt is required for stability ai recolor task")
}
if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") {
return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task")
}
return req, nil
}
// ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response
func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostImageGenerationResponse{}
if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 {
bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{
FinishReasons: append([]*string(nil), response.FinishReasons...),
Seeds: append([]int(nil), response.Seeds...),
}
}
for index, image := range response.Images {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{
B64JSON: image,
Index: index,
})
}
return bifrostResponse
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,130 @@
package bedrock
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// BedrockRerankRequest is the Bedrock Agent Runtime rerank request body.
type BedrockRerankRequest struct {
Queries []BedrockRerankQuery `json:"queries"`
Sources []BedrockRerankSource `json:"sources"`
RerankingConfiguration BedrockRerankingConfiguration `json:"rerankingConfiguration"`
}
// GetExtraParams implements RequestBodyWithExtraParams.
func (*BedrockRerankRequest) GetExtraParams() map[string]interface{} {
return nil
}
const (
bedrockRerankQueryTypeText = "TEXT"
bedrockRerankSourceTypeInline = "INLINE"
bedrockRerankInlineDocumentTypeText = "TEXT"
bedrockRerankConfigurationTypeBedrock = "BEDROCK_RERANKING_MODEL"
)
type BedrockRerankQuery struct {
Type string `json:"type"`
TextQuery BedrockRerankTextRef `json:"textQuery"`
}
type BedrockRerankSource struct {
Type string `json:"type"`
InlineDocumentSource BedrockRerankInlineSource `json:"inlineDocumentSource"`
}
type BedrockRerankInlineSource struct {
Type string `json:"type"`
TextDocument BedrockRerankTextValue `json:"textDocument"`
}
type BedrockRerankTextRef struct {
Text string `json:"text"`
}
type BedrockRerankTextValue struct {
Text string `json:"text"`
}
type BedrockRerankingConfiguration struct {
Type string `json:"type"`
BedrockRerankingConfiguration BedrockRerankingModelConfiguration `json:"bedrockRerankingConfiguration"`
}
type BedrockRerankingModelConfiguration struct {
ModelConfiguration BedrockRerankModelConfiguration `json:"modelConfiguration"`
NumberOfResults *int `json:"numberOfResults,omitempty"`
}
type BedrockRerankModelConfiguration struct {
ModelARN string `json:"modelArn"`
AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
}
// BedrockRerankResponse is the Bedrock Agent Runtime rerank response body.
type BedrockRerankResponse struct {
Results []BedrockRerankResult `json:"results"`
NextToken *string `json:"nextToken,omitempty"`
}
type BedrockRerankResult struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevanceScore"`
Document *BedrockRerankResponseDocument `json:"document,omitempty"`
}
type BedrockRerankResponseDocument struct {
Type string `json:"type,omitempty"`
TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"`
}
func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostListModelsResponse{
Data: make([]schemas.Model, 0, len(response.ModelSummaries)),
}
pipeline := &providerUtils.ListModelsPipeline{
AllowedModels: allowedModels,
BlacklistedModels: blacklistedModels,
Aliases: aliases,
Unfiltered: unfiltered,
ProviderKey: providerKey,
MatchFns: providerUtils.DefaultMatchFns(),
}
if pipeline.ShouldEarlyExit() {
return bifrostResponse
}
included := make(map[string]bool)
for _, model := range response.ModelSummaries {
for _, result := range pipeline.FilterModel(model.ModelID) {
modelEntry := schemas.Model{
ID: string(providerKey) + "/" + result.ResolvedID,
Name: schemas.Ptr(model.ModelName),
OwnedBy: schemas.Ptr(model.ProviderName),
Architecture: &schemas.Architecture{
InputModalities: model.InputModalities,
OutputModalities: model.OutputModalities,
},
}
if result.AliasValue != "" {
modelEntry.Alias = schemas.Ptr(result.AliasValue)
}
bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
included[strings.ToLower(result.ResolvedID)] = true
}
}
bifrostResponse.Data = append(bifrostResponse.Data,
pipeline.BackfillModels(included)...)
return bifrostResponse
}

View File

@@ -0,0 +1,55 @@
package bedrock
import (
"encoding/json"
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPayloadOrdering_BedrockConverseRequest(t *testing.T) {
req := &BedrockConverseRequest{
Messages: []BedrockMessage{
{
Role: "user",
Content: []BedrockContentBlock{
{Text: schemas.Ptr("hello")},
},
},
},
InferenceConfig: &BedrockInferenceConfig{
Temperature: schemas.Ptr(0.7),
MaxTokens: schemas.Ptr(1024),
},
ToolConfig: &BedrockToolConfig{
Tools: []BedrockTool{
{
ToolSpec: &BedrockToolSpec{
Name: "get_weather",
Description: schemas.Ptr("Get weather"),
InputSchema: BedrockToolInputSchema{
JSON: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
},
},
},
}
result, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
golden := `{"messages":[{"role":"user","content":[{"text":"hello"}]}],"inferenceConfig":{"maxTokens":1024,"temperature":0.7},"toolConfig":{"tools":[{"toolSpec":{"name":"get_weather","description":"Get weather","inputSchema":{"json":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}}]}}`
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
// Determinism: 100 iterations must produce identical bytes
for i := 0; i < 100; i++ {
iter, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
}
}

View File

@@ -0,0 +1,168 @@
package bedrock
import (
"fmt"
"sort"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockRerankRequest converts a Bifrost rerank request into Bedrock Agent Runtime format.
func ToBedrockRerankRequest(bifrostReq *schemas.BifrostRerankRequest, modelARN string) (*BedrockRerankRequest, error) {
if bifrostReq == nil {
return nil, fmt.Errorf("bifrost rerank request is nil")
}
if strings.TrimSpace(modelARN) == "" {
return nil, fmt.Errorf("bedrock rerank model ARN is empty")
}
if len(bifrostReq.Documents) == 0 {
return nil, fmt.Errorf("documents are required for rerank request")
}
bedrockReq := &BedrockRerankRequest{
Queries: []BedrockRerankQuery{
{
Type: bedrockRerankQueryTypeText,
TextQuery: BedrockRerankTextRef{
Text: bifrostReq.Query,
},
},
},
Sources: make([]BedrockRerankSource, len(bifrostReq.Documents)),
RerankingConfiguration: BedrockRerankingConfiguration{
Type: bedrockRerankConfigurationTypeBedrock,
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
ModelConfiguration: BedrockRerankModelConfiguration{
ModelARN: modelARN,
},
},
},
}
for i, doc := range bifrostReq.Documents {
bedrockReq.Sources[i] = BedrockRerankSource{
Type: bedrockRerankSourceTypeInline,
InlineDocumentSource: BedrockRerankInlineSource{
Type: bedrockRerankInlineDocumentTypeText,
TextDocument: BedrockRerankTextValue{
Text: doc.Text,
},
},
}
}
if bifrostReq.Params == nil {
return bedrockReq, nil
}
if bifrostReq.Params.TopN != nil {
topN := *bifrostReq.Params.TopN
if topN < 1 {
return nil, fmt.Errorf("top_n must be at least 1")
}
if topN > len(bifrostReq.Documents) {
topN = len(bifrostReq.Documents)
}
bedrockReq.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults = schemas.Ptr(topN)
}
additionalFields := make(map[string]interface{})
if bifrostReq.Params.MaxTokensPerDoc != nil {
additionalFields["max_tokens_per_doc"] = *bifrostReq.Params.MaxTokensPerDoc
}
if bifrostReq.Params.Priority != nil {
additionalFields["priority"] = *bifrostReq.Params.Priority
}
for k, v := range bifrostReq.Params.ExtraParams {
additionalFields[k] = v
}
if len(additionalFields) > 0 {
bedrockReq.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields = additionalFields
}
return bedrockReq, nil
}
// ToBifrostRerankResponse converts a Bedrock rerank response into Bifrost format.
func (response *BedrockRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostRerankResponse{
Results: make([]schemas.RerankResult, 0, len(response.Results)),
}
for _, result := range response.Results {
rerankResult := schemas.RerankResult{
Index: result.Index,
RelevanceScore: result.RelevanceScore,
}
if result.Document != nil && result.Document.TextDocument != nil {
rerankResult.Document = &schemas.RerankDocument{
Text: result.Document.TextDocument.Text,
}
}
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
}
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
}
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
})
if returnDocuments {
for i := range bifrostResponse.Results {
resultIndex := bifrostResponse.Results[i].Index
if resultIndex >= 0 && resultIndex < len(documents) {
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
}
}
}
return bifrostResponse
}
// ToBifrostRerankRequest converts a Bedrock Agent Runtime rerank request to Bifrost format.
func (req *BedrockRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
if req == nil {
return nil
}
modelARN := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.ModelARN
provider, model := schemas.ParseModelString(modelARN, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock))
bifrostReq := &schemas.BifrostRerankRequest{
Provider: provider,
Model: model,
Params: &schemas.RerankParameters{},
}
// Extract query from the first query entry
if len(req.Queries) > 0 {
bifrostReq.Query = req.Queries[0].TextQuery.Text
}
// Convert sources to documents
for _, source := range req.Sources {
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
Text: source.InlineDocumentSource.TextDocument.Text,
})
}
// Extract TopN from NumberOfResults
if req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults != nil {
bifrostReq.Params.TopN = req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults
}
// Pass AdditionalModelRequestFields as ExtraParams
if fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields; len(fields) > 0 {
bifrostReq.Params.ExtraParams = fields
}
return bifrostReq
}

View File

@@ -0,0 +1,230 @@
package bedrock
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToBedrockRerankRequest(t *testing.T) {
topN := 10
maxTokensPerDoc := 512
priority := 3
req, err := ToBedrockRerankRequest(&schemas.BifrostRerankRequest{
Model: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France."},
{Text: "Berlin is the capital of Germany."},
},
Params: &schemas.RerankParameters{
TopN: schemas.Ptr(topN),
MaxTokensPerDoc: schemas.Ptr(maxTokensPerDoc),
Priority: schemas.Ptr(priority),
ExtraParams: map[string]interface{}{
"truncate": "END",
},
},
}, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0")
require.NoError(t, err)
require.NotNil(t, req)
require.Len(t, req.Queries, 1)
assert.Equal(t, "TEXT", req.Queries[0].Type)
assert.Equal(t, "capital of france", req.Queries[0].TextQuery.Text)
require.Len(t, req.Sources, 2)
require.NotNil(t, req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults)
assert.Equal(t, 2, *req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults, "top_n must be clamped to source count")
fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields
require.NotNil(t, fields)
assert.Equal(t, maxTokensPerDoc, fields["max_tokens_per_doc"])
assert.Equal(t, priority, fields["priority"])
assert.Equal(t, "END", fields["truncate"])
}
func TestBedrockRerankResponseToBifrostRerankResponse(t *testing.T) {
response := (&BedrockRerankResponse{
Results: []BedrockRerankResult{
{
Index: 2,
RelevanceScore: 0.21,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "doc-2"},
},
},
{
Index: 1,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "doc-1"},
},
},
{
Index: 0,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "doc-0"},
},
},
},
}).ToBifrostRerankResponse(nil, false)
require.NotNil(t, response)
require.Len(t, response.Results, 3)
assert.Equal(t, 0, response.Results[0].Index)
assert.Equal(t, 1, response.Results[1].Index)
assert.Equal(t, 2, response.Results[2].Index)
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
assert.Equal(t, "doc-1", response.Results[1].Document.Text)
}
func TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
requestDocs := []schemas.RerankDocument{
{Text: "request-doc-0"},
{Text: "request-doc-1"},
{Text: "request-doc-2"},
}
response := (&BedrockRerankResponse{
Results: []BedrockRerankResult{
{
Index: 2,
RelevanceScore: 0.21,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-2"},
},
},
{
Index: 1,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-1"},
},
},
{
Index: 0,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-0"},
},
},
},
}).ToBifrostRerankResponse(requestDocs, true)
require.NotNil(t, response)
require.Len(t, response.Results, 3)
require.NotNil(t, response.Results[0].Document)
require.NotNil(t, response.Results[1].Document)
require.NotNil(t, response.Results[2].Document)
assert.Equal(t, 0, response.Results[0].Index)
assert.Equal(t, 1, response.Results[1].Index)
assert.Equal(t, 2, response.Results[2].Index)
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
assert.Equal(t, "request-doc-2", response.Results[2].Document.Text)
}
func TestBedrockRerankRequestToBifrostRerankRequest(t *testing.T) {
topN := 3
bedrockReq := &BedrockRerankRequest{
Queries: []BedrockRerankQuery{
{
Type: bedrockRerankQueryTypeText,
TextQuery: BedrockRerankTextRef{Text: "capital of france"},
},
},
Sources: []BedrockRerankSource{
{
Type: bedrockRerankSourceTypeInline,
InlineDocumentSource: BedrockRerankInlineSource{
Type: bedrockRerankInlineDocumentTypeText,
TextDocument: BedrockRerankTextValue{Text: "Paris is the capital of France."},
},
},
{
Type: bedrockRerankSourceTypeInline,
InlineDocumentSource: BedrockRerankInlineSource{
Type: bedrockRerankInlineDocumentTypeText,
TextDocument: BedrockRerankTextValue{Text: "Berlin is the capital of Germany."},
},
},
},
RerankingConfiguration: BedrockRerankingConfiguration{
Type: bedrockRerankConfigurationTypeBedrock,
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
NumberOfResults: &topN,
ModelConfiguration: BedrockRerankModelConfiguration{
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
AdditionalModelRequestFields: map[string]interface{}{
"truncate": "END",
},
},
},
},
}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
result := bedrockReq.ToBifrostRerankRequest(bifrostCtx)
require.NotNil(t, result)
assert.Equal(t, schemas.Bedrock, result.Provider)
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", result.Model)
assert.Equal(t, "capital of france", result.Query)
require.Len(t, result.Documents, 2)
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
require.NotNil(t, result.Params)
require.NotNil(t, result.Params.TopN)
assert.Equal(t, 3, *result.Params.TopN)
require.NotNil(t, result.Params.ExtraParams)
assert.Equal(t, "END", result.Params.ExtraParams["truncate"])
}
func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) {
var req *BedrockRerankRequest
assert.Nil(t, req.ToBifrostRerankRequest(nil))
}
func TestResolveBedrockDeployment(t *testing.T) {
key := schemas.Key{
Aliases: schemas.KeyAliases{
"cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
},
}
deployment := key.Aliases.Resolve("cohere-rerank")
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment)
assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0"))
assert.Equal(t, "", key.Aliases.Resolve(""))
}
func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) {
provider := &BedrockProvider{}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
key := schemas.Key{
Aliases: schemas.KeyAliases{
"cohere-rerank": "cohere.rerank-v3-5:0",
},
}
response, bifrostErr := provider.Rerank(ctx, key, &schemas.BifrostRerankRequest{
Model: "cohere-rerank",
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France."},
},
})
require.Nil(t, response)
require.NotNil(t, bifrostErr)
require.NotNil(t, bifrostErr.Error)
assert.Contains(t, bifrostErr.Error.Message, "requires an ARN")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,130 @@
package bedrock
import (
"bytes"
"context"
"fmt"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// uploadToS3 uploads content to an S3 bucket using the provided credentials.
func uploadToS3(
ctx context.Context,
accessKey, secretKey string,
sessionToken *string,
region string,
bucket, key string,
content []byte,
) *schemas.BifrostError {
// Create AWS config with credentials
var cfg aws.Config
var err error
if accessKey != "" && secretKey != "" {
// Use provided credentials
var creds aws.CredentialsProvider
if sessionToken != nil && *sessionToken != "" {
creds = credentials.NewStaticCredentialsProvider(accessKey, secretKey, *sessionToken)
} else {
creds = credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
}
cfg, err = config.LoadDefaultConfig(ctx,
config.WithRegion(region),
config.WithCredentialsProvider(creds),
)
} else {
// Use default credentials chain (IAM role, env vars, etc.)
cfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(region))
}
if err != nil {
return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err)
}
// Create S3 client
client := s3.NewFromConfig(cfg)
// Upload the content
_, err = client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: bytes.NewReader(content),
ContentType: aws.String("application/jsonl"),
})
if err != nil {
return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err)
}
return nil
}
// generateBatchInputS3Key generates a unique S3 key for batch input files.
func generateBatchInputS3Key(jobName string) string {
timestamp := time.Now().UnixNano()
return fmt.Sprintf("bifrost-batch-input/%s-%d.jsonl", jobName, timestamp)
}
// deriveInputS3URIFromOutput derives an input S3 URI from the output S3 URI.
// It uses the same bucket but with a different path for input files.
func deriveInputS3URIFromOutput(outputS3URI, inputKey string) string {
bucket, _ := parseS3URI(outputS3URI)
return fmt.Sprintf("s3://%s/%s", bucket, inputKey)
}
// ConvertBedrockRequestsToJSONL converts batch request items to JSONL format for Bedrock.
// Bedrock uses a specific format for batch inference requests.
func ConvertBedrockRequestsToJSONL(requests []schemas.BatchRequestItem, modelID *string) ([]byte, error) {
// Model ID is required for Bedrock batch JSONL conversion
if modelID == nil || *modelID == "" {
return nil, fmt.Errorf("modelID is required for Bedrock batch JSONL conversion")
}
// Initialize the buffer
var buf bytes.Buffer
// Iterate over the requests
for _, req := range requests {
// Build the Bedrock batch request format
bedrockReq := map[string]interface{}{
"recordId": req.CustomID,
"modelInput": map[string]interface{}{
"modelId": *modelID,
},
}
// If the request has a body, use it as the model input parameters
if req.Body != nil {
modelInput := bedrockReq["modelInput"].(map[string]interface{})
for k, v := range req.Body {
if k != "model" { // Don't override modelId
modelInput[k] = v
}
}
} else if req.Params != nil {
modelInput := bedrockReq["modelInput"].(map[string]interface{})
for k, v := range req.Params {
if k != "model" {
modelInput[k] = v
}
}
}
// Marshal the request as a JSON line
line, err := providerUtils.MarshalSorted(bedrockReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch request item %s: %w", req.CustomID, err)
}
buf.Write(line)
buf.WriteByte('\n')
}
return buf.Bytes(), nil
}

View File

@@ -0,0 +1,433 @@
package bedrock
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/smithy-go/encoding/httpbinding"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
const (
signingAlgorithm = "AWS4-HMAC-SHA256"
amzDateKey = "X-Amz-Date"
amzSecurityToken = "X-Amz-Security-Token"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
)
// Headers to ignore during signing
var ignoredHeaders = map[string]struct{}{
"authorization": {},
"user-agent": {},
"x-amzn-trace-id": {},
"expect": {},
"transfer-encoding": {},
}
// signingKeyCache caches derived signing keys to avoid recomputation
type signingKeyCache struct {
cache map[string]cachedKey
mu sync.RWMutex
}
type cachedKey struct {
key []byte
date string // YYYYMMDD format
accessKey string
}
var keyCache = &signingKeyCache{
cache: make(map[string]cachedKey),
}
// hmacSHA256 computes HMAC-SHA256
func hmacSHA256(key, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}
// deriveSigningKey derives the AWS signing key
func deriveSigningKey(secret, dateStamp, region, service string) []byte {
kDate := hmacSHA256([]byte("AWS4"+secret), []byte(dateStamp))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(service))
kSigning := hmacSHA256(kService, []byte("aws4_request"))
return kSigning
}
// getSigningKey retrieves or computes the signing key with caching
func getSigningKey(accessKey, secretKey, dateStamp, region, service string) []byte {
cacheKey := fmt.Sprintf("%s/%s/%s/%s", accessKey, dateStamp, region, service)
keyCache.mu.RLock()
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
keyCache.mu.RUnlock()
return cached.key
}
keyCache.mu.RUnlock()
keyCache.mu.Lock()
defer keyCache.mu.Unlock()
// Double-check after acquiring write lock
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
return cached.key
}
key := deriveSigningKey(secretKey, dateStamp, region, service)
keyCache.cache[cacheKey] = cachedKey{
key: key,
date: dateStamp,
accessKey: accessKey,
}
return key
}
// stripExcessSpaces removes excess spaces from a string
func stripExcessSpaces(str string) string {
str = strings.TrimSpace(str)
if !strings.Contains(str, " ") {
return str
}
var result strings.Builder
result.Grow(len(str))
prevWasSpace := false
for _, ch := range str {
if ch == ' ' {
if !prevWasSpace {
result.WriteRune(ch)
}
prevWasSpace = true
} else {
result.WriteRune(ch)
prevWasSpace = false
}
}
return result.String()
}
// percentEncodeRFC3986 encodes a string per RFC 3986
// Keep unreserved characters (A-Z, a-z, 0-9, -, _, ., ~) as-is
// Percent-encode everything else as %HH using uppercase hex
func percentEncodeRFC3986(s string) string {
var result strings.Builder
result.Grow(len(s))
for i := 0; i < len(s); i++ {
b := s[i]
// RFC 3986 unreserved characters
if (b >= 'A' && b <= 'Z') ||
(b >= 'a' && b <= 'z') ||
(b >= '0' && b <= '9') ||
b == '-' || b == '_' || b == '.' || b == '~' {
result.WriteByte(b)
} else {
// Percent-encode with uppercase hex
result.WriteByte('%')
result.WriteByte(uppercaseHex(b >> 4))
result.WriteByte(uppercaseHex(b & 0x0F))
}
}
return result.String()
}
// uppercaseHex returns the uppercase hex character for a nibble (0-15)
func uppercaseHex(b byte) byte {
if b < 10 {
return '0' + b
}
return 'A' + (b - 10)
}
// percentDecode decodes percent-encoded sequences in a string without treating + as space
// This differs from url.QueryUnescape which uses form encoding (+ becomes space)
func percentDecode(s string) string {
// Quick check if there are any percent signs
if !strings.Contains(s, "%") {
return s
}
var result strings.Builder
result.Grow(len(s))
for i := 0; i < len(s); {
if s[i] == '%' && i+2 < len(s) {
// Try to decode the hex sequence
if h1 := unhex(s[i+1]); h1 >= 0 {
if h2 := unhex(s[i+2]); h2 >= 0 {
result.WriteByte(byte(h1<<4 | h2))
i += 3
continue
}
}
}
result.WriteByte(s[i])
i++
}
return result.String()
}
// unhex converts a hex character to its value, or -1 if not a hex char
func unhex(c byte) int {
switch {
case '0' <= c && c <= '9':
return int(c - '0')
case 'a' <= c && c <= 'f':
return int(c - 'a' + 10)
case 'A' <= c && c <= 'F':
return int(c - 'A' + 10)
}
return -1
}
// queryPair represents a query parameter name-value pair
type queryPair struct {
encodedName string
encodedValue string
}
// buildCanonicalQueryString builds a canonical query string per AWS SigV4 spec
// using proper RFC 3986 percent-encoding
func buildCanonicalQueryString(queryString string) string {
if queryString == "" {
return ""
}
// Split the raw query string on '&' into pairs
rawPairs := strings.Split(queryString, "&")
pairs := make([]queryPair, 0, len(rawPairs))
for _, rawPair := range rawPairs {
if rawPair == "" {
continue
}
// Split on the first '=' to get name and value
var name, value string
if idx := strings.IndexByte(rawPair, '='); idx >= 0 {
name = rawPair[:idx]
value = rawPair[idx+1:]
} else {
// No '=' means name only, empty value
name = rawPair
value = ""
}
// Decode percent-encoded sequences first to normalize (handles already-encoded values)
// then encode per RFC 3986 to ensure consistent encoding
// Note: We use percentDecode instead of url.QueryUnescape because the latter
// treats + as space (form encoding), but we need + to encode as %2B
decodedName := percentDecode(name)
decodedValue := percentDecode(value)
// Percent-encode name and value per RFC 3986
encodedName := percentEncodeRFC3986(decodedName)
encodedValue := percentEncodeRFC3986(decodedValue)
pairs = append(pairs, queryPair{
encodedName: encodedName,
encodedValue: encodedValue,
})
}
// Sort pairs lexicographically by encoded name, then by encoded value
sort.Slice(pairs, func(i, j int) bool {
if pairs[i].encodedName != pairs[j].encodedName {
return pairs[i].encodedName < pairs[j].encodedName
}
return pairs[i].encodedValue < pairs[j].encodedValue
})
// Join encoded pairs with '&'
var result strings.Builder
for i, pair := range pairs {
if i > 0 {
result.WriteByte('&')
}
result.WriteString(pair.encodedName)
result.WriteByte('=')
result.WriteString(pair.encodedValue)
}
return result.String()
}
// signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4
// This is a native implementation that avoids allocating http.Request
func signAWSRequestFastHTTP(
ctx context.Context,
req *fasthttp.Request,
body []byte,
accessKey, secretKey string,
sessionToken *string,
region, service string,
) *schemas.BifrostError {
// Get AWS credentials if not provided
if accessKey == "" && secretKey == "" {
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return providerUtils.NewBifrostOperationError("failed to load aws config", err)
}
creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
}
accessKey = creds.AccessKeyID
secretKey = creds.SecretAccessKey
if creds.SessionToken != "" {
st := creds.SessionToken
sessionToken = &st
}
}
// Get current time
now := time.Now().UTC()
amzDate := now.Format(timeFormat)
dateStamp := now.Format(shortTimeFormat)
// Parse URI
uri := req.URI()
host := string(uri.Host())
path := string(uri.Path())
if path == "" {
path = "/"
}
queryString := string(uri.QueryString())
// Escape path for canonical URI (Bedrock doesn't disable escaping)
canonicalURI := httpbinding.EscapePath(path, false)
// Calculate payload hash
hash := sha256.Sum256(body)
payloadHash := hex.EncodeToString(hash[:])
// Set required headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set(amzDateKey, amzDate)
if sessionToken != nil && *sessionToken != "" {
req.Header.Set(amzSecurityToken, *sessionToken)
}
// Build canonical headers
var headerNames []string
headerMap := make(map[string][]string)
// Always include host
headerNames = append(headerNames, "host")
headerMap["host"] = []string{host}
// Include content-length if body is present
if cl := req.Header.ContentLength(); cl >= 0 {
headerNames = append(headerNames, "content-length")
headerMap["content-length"] = []string{strconv.Itoa(cl)}
}
// Collect other headers
for key, value := range req.Header.All() {
keyStr := strings.ToLower(string(key))
// Skip ignored headers
if _, ignore := ignoredHeaders[keyStr]; ignore {
continue
}
// Skip if already handled
if keyStr == "host" || keyStr == "content-length" {
continue
}
if _, exists := headerMap[keyStr]; !exists {
headerNames = append(headerNames, keyStr)
}
headerMap[keyStr] = append(headerMap[keyStr], string(value))
}
// Sort header names
sort.Strings(headerNames)
// Build canonical headers string
var canonicalHeaders strings.Builder
for _, name := range headerNames {
canonicalHeaders.WriteString(name)
canonicalHeaders.WriteRune(':')
values := headerMap[name]
for i, v := range values {
cleanedValue := stripExcessSpaces(v)
canonicalHeaders.WriteString(cleanedValue)
if i < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
canonicalHeaders.WriteRune('\n')
}
signedHeaders := strings.Join(headerNames, ";")
// Build canonical query string using RFC 3986 encoding
canonicalQueryString := buildCanonicalQueryString(queryString)
// Build canonical request
canonicalRequest := strings.Join([]string{
string(req.Header.Method()),
canonicalURI,
canonicalQueryString,
canonicalHeaders.String(),
signedHeaders,
payloadHash,
}, "\n")
// Build credential scope
credentialScope := strings.Join([]string{
dateStamp,
region,
service,
"aws4_request",
}, "/")
// Build string to sign
canonicalRequestHash := sha256.Sum256([]byte(canonicalRequest))
stringToSign := strings.Join([]string{
signingAlgorithm,
amzDate,
credentialScope,
hex.EncodeToString(canonicalRequestHash[:]),
}, "\n")
// Calculate signature
signingKey := getSigningKey(accessKey, secretKey, dateStamp, region, service)
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
// Build authorization header
authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
signingAlgorithm,
accessKey,
credentialScope,
signedHeaders,
signature,
)
req.Header.Set("Authorization", authHeader)
return nil
}

View File

@@ -0,0 +1,229 @@
package bedrock
import (
"strings"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockTextCompletionRequest converts a Bifrost text completion request to Bedrock format
func ToBedrockTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *BedrockTextCompletionRequest {
if bifrostReq == nil || (bifrostReq.Input.PromptStr == nil && len(bifrostReq.Input.PromptArray) == 0) {
return nil
}
// Extract the raw prompt from bifrostReq
prompt := ""
if bifrostReq.Input != nil {
if bifrostReq.Input.PromptStr != nil {
prompt = *bifrostReq.Input.PromptStr
} else if len(bifrostReq.Input.PromptArray) > 0 && bifrostReq.Input.PromptArray != nil {
prompt = strings.Join(bifrostReq.Input.PromptArray, "\n\n")
}
}
bedrockReq := &BedrockTextCompletionRequest{
Prompt: prompt,
}
// Apply parameters
if bifrostReq.Params != nil {
bedrockReq.Temperature = bifrostReq.Params.Temperature
bedrockReq.TopP = bifrostReq.Params.TopP
if bifrostReq.Params.ExtraParams != nil {
bedrockReq.ExtraParams = bifrostReq.Params.ExtraParams
if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok {
delete(bedrockReq.ExtraParams, "top_k")
bedrockReq.TopK = topK
}
}
}
// Apply model-specific formatting and field naming
if strings.Contains(bifrostReq.Model, "anthropic.") || strings.Contains(bifrostReq.Model, "claude") {
// For Claude models, wrap the prompt in Anthropic format and use Anthropic field names
anthropicReq := anthropic.ToAnthropicTextCompletionRequest(bifrostReq)
bedrockReq.Prompt = anthropicReq.Prompt
bedrockReq.MaxTokensToSample = &anthropicReq.MaxTokensToSample
bedrockReq.StopSequences = anthropicReq.StopSequences
} else {
// For other models, use standard field names with raw prompt
if bifrostReq.Params != nil {
bedrockReq.MaxTokens = bifrostReq.Params.MaxTokens
bedrockReq.Stop = bifrostReq.Params.Stop
}
}
return bedrockReq
}
// ToBifrostTextCompletionRequest converts a Bedrock text completion request to Bifrost format
func (request *BedrockTextCompletionRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
if request == nil {
return nil
}
prompt := request.Prompt
// Fallback for Claude 3 Messages API
if prompt == "" && len(request.Messages) > 0 {
var parts []string
for _, msg := range request.Messages {
for _, content := range msg.Content {
if content.Text != nil {
parts = append(parts, *content.Text)
}
}
}
prompt = strings.Join(parts, "\n\n")
}
provider, model := schemas.ParseModelString(request.ModelID, utils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock))
bifrostReq := &schemas.BifrostTextCompletionRequest{
Provider: provider,
Model: model,
Input: &schemas.TextCompletionInput{
PromptStr: &prompt,
},
Params: &schemas.TextCompletionParameters{
Temperature: request.Temperature,
TopP: request.TopP,
},
}
if request.MaxTokens != nil {
bifrostReq.Params.MaxTokens = request.MaxTokens
} else if request.MaxTokensToSample != nil {
bifrostReq.Params.MaxTokens = request.MaxTokensToSample
}
if len(request.Stop) > 0 {
bifrostReq.Params.Stop = request.Stop
} else if len(request.StopSequences) > 0 {
bifrostReq.Params.Stop = request.StopSequences
}
return bifrostReq
}
// ToBifrostTextCompletionResponse converts a Bedrock Anthropic text response to Bifrost format
func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
if response == nil {
return nil
}
return &schemas.BifrostTextCompletionResponse{
Object: "text_completion",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
Text: &response.Completion,
},
FinishReason: &response.StopReason,
},
},
ExtraFields: schemas.BifrostResponseExtraFields{
},
}
}
// ToBifrostTextCompletionResponse converts a Bedrock Mistral text response to Bifrost format
func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
if response == nil {
return nil
}
var choices []schemas.BifrostResponseChoice
for i, output := range response.Outputs {
choices = append(choices, schemas.BifrostResponseChoice{
Index: i,
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
Text: &output.Text,
},
FinishReason: &output.StopReason,
})
}
return &schemas.BifrostTextCompletionResponse{
Object: "text_completion",
Choices: choices,
ExtraFields: schemas.BifrostResponseExtraFields{
},
}
}
// ToBedrockTextCompletionResponse converts a BifrostTextCompletionResponse back to Bedrock text completion format
// Returns either *BedrockAnthropicTextResponse or *BedrockMistralTextResponse based on the model
func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionResponse) interface{} {
if bifrostResp == nil {
return nil
}
// Determine response format based on resolved model identity.
// Use ResolvedModelUsed (actual provider ID) for accurate family detection,
// falling back to bifrostResp.Model, then OriginalModelRequested as a last resort.
model := bifrostResp.Model
if bifrostResp.ExtraFields.ResolvedModelUsed != "" {
model = bifrostResp.ExtraFields.ResolvedModelUsed
} else if model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "" {
model = bifrostResp.ExtraFields.OriginalModelRequested
}
if strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") {
// Convert to Anthropic format
bedrockResp := &BedrockAnthropicTextResponse{}
// Convert choices to completion text
if len(bifrostResp.Choices) > 0 {
choice := bifrostResp.Choices[0] // Anthropic text API typically returns one choice
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
bedrockResp.Completion = *choice.TextCompletionResponseChoice.Text
}
if choice.FinishReason != nil {
bedrockResp.StopReason = *choice.FinishReason
}
}
return bedrockResp
} else if strings.Contains(model, "mistral.") {
// Convert to Mistral format
bedrockResp := &BedrockMistralTextResponse{}
// Convert choices to outputs
for _, choice := range bifrostResp.Choices {
var output struct {
Text string `json:"text"`
StopReason string `json:"stop_reason"`
}
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
output.Text = *choice.TextCompletionResponseChoice.Text
}
if choice.FinishReason != nil {
output.StopReason = *choice.FinishReason
}
bedrockResp.Outputs = append(bedrockResp.Outputs, output)
}
return bedrockResp
}
// Default to Anthropic format if model type cannot be determined
bedrockResp := &BedrockAnthropicTextResponse{}
if len(bifrostResp.Choices) > 0 {
choice := bifrostResp.Choices[0]
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
bedrockResp.Completion = *choice.TextCompletionResponseChoice.Text
}
if choice.FinishReason != nil {
bedrockResp.StopReason = *choice.FinishReason
}
}
return bedrockResp
}

View File

@@ -0,0 +1,714 @@
package bedrock
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// redirectTransport is an http.RoundTripper that rewrites every request's
// host/scheme to a fixed target URL, used to redirect provider requests to a
// local httptest.Server without modifying provider code.
type redirectTransport struct {
target *url.URL
transport http.RoundTripper
}
func (r *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) {
cloned := req.Clone(req.Context())
cloned.URL.Scheme = r.target.Scheme
cloned.URL.Host = r.target.Host
cloned.Host = r.target.Host
return r.transport.RoundTrip(cloned)
}
// noopLogger is a no-op schemas.Logger for use in tests.
type noopLogger struct{}
func (noopLogger) Debug(string, ...any) {}
func (noopLogger) Info(string, ...any) {}
func (noopLogger) Warn(string, ...any) {}
func (noopLogger) Error(string, ...any) {}
func (noopLogger) Fatal(string, ...any) {}
func (noopLogger) SetLevel(schemas.LogLevel) {}
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// newTestProviderWithServer returns a BedrockProvider whose HTTP client is
// redirected to the given httptest.Server.
func newTestProviderWithServer(t *testing.T, ts *httptest.Server) *BedrockProvider {
t.Helper()
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 5,
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, noopLogger{})
require.NoError(t, err)
targetURL, err := url.Parse(ts.URL)
require.NoError(t, err)
redirect := &redirectTransport{
target: targetURL,
transport: ts.Client().Transport,
}
provider.client = &http.Client{
Transport: redirect,
Timeout: 5 * time.Second,
}
// Streaming paths use streamingClient (no Timeout); redirect it to the
// test server too, otherwise Bedrock streaming tests would hit the real
// AWS endpoint.
provider.streamingClient = &http.Client{Transport: redirect}
return provider
}
// testBedrockKey returns a minimal Key with a bearer value so makeStreamingRequest
// skips IAM signing and proceeds to the HTTP call.
func testBedrockKey() schemas.Key {
region := schemas.NewEnvVar("us-east-1")
return schemas.Key{
Value: *schemas.NewEnvVar("test-api-key"),
BedrockKeyConfig: &schemas.BedrockKeyConfig{
Region: region,
},
}
}
// testBedrockCtx returns a BifrostContext suitable for unit tests.
func testBedrockCtx() *schemas.BifrostContext {
return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
}
// noopPostHookRunner is a PostHookRunner that passes through results unchanged.
func noopPostHookRunner(_ *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return result, err
}
// testChatRequest returns a minimal BifrostChatRequest for streaming tests.
func testChatRequest() *schemas.BifrostChatRequest {
content := "hello"
return &schemas.BifrostChatRequest{
Model: "anthropic.claude-sonnet-4-5",
Input: []schemas.ChatMessage{
{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{ContentStr: &content},
},
},
}
}
// TestMakeStreamingRequest_StaleConnection_IsRetryable verifies that when the
// HTTP server closes the connection before sending a response (simulating a
// stale HTTP/2 connection), makeStreamingRequest returns a BifrostError with
// IsBifrostError:false so the retry gate in executeRequestWithRetries retries.
func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) {
// Server that immediately closes the connection without sending anything.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hj, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "hijack not supported", http.StatusInternalServerError)
return
}
conn, _, _ := hj.Hijack()
conn.Close() // close without writing any response
}))
defer ts.Close()
provider := newTestProviderWithServer(t, ts)
ctx := testBedrockCtx()
key := testBedrockKey()
_, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream")
require.NotNil(t, bifrostErr, "expected error when server closes connection")
assert.False(t, bifrostErr.IsBifrostError,
"stale-connection error must be IsBifrostError:false so the retry gate can retry it")
require.NotNil(t, bifrostErr.Error)
// Either ErrProviderNetworkError (net.OpError) or ErrProviderDoRequest (EOF/connection-reset)
// are both retryable — the key invariant is IsBifrostError:false.
assert.Contains(t, []string{schemas.ErrProviderNetworkError, schemas.ErrProviderDoRequest}, bifrostErr.Error.Message,
"stale-connection error must use a retryable error message")
}
// TestChatCompletionStream_StaleConnection_ChunkIsRetryable verifies that when
// the server returns HTTP 200 but closes the body immediately (simulating a
// stale connection mid-stream before any EventStream data arrives), the first
// chunk received from the stream channel carries a BifrostError with
// IsBifrostError:false so that CheckFirstStreamChunkForError + the retry gate
// can transparently retry the request.
func TestChatCompletionStream_StaleConnection_ChunkIsRetryable(t *testing.T) {
// Server: returns 200 with the correct content-type but closes body immediately.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
hj, ok := w.(http.Hijacker)
if !ok {
return
}
conn, _, _ := hj.Hijack()
conn.Close() // close without any EventStream bytes
}))
defer ts.Close()
provider := newTestProviderWithServer(t, ts)
ctx := testBedrockCtx()
key := testBedrockKey()
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
if bifrostErr != nil {
// Error surfaced synchronously (e.g. connection refused before HTTP 200).
assert.False(t, bifrostErr.IsBifrostError,
"pre-stream network error must be IsBifrostError:false")
return
}
// Error surfaced as the first stream chunk.
require.NotNil(t, streamChan)
chunk, ok := <-streamChan
require.True(t, ok, "channel must not be empty")
require.NotNil(t, chunk)
require.NotNil(t, chunk.BifrostError, "expected an error chunk from the stream")
assert.False(t, chunk.BifrostError.IsBifrostError,
"stream transport error must be IsBifrostError:false so the retry gate can retry it")
require.NotNil(t, chunk.BifrostError.Error)
assert.Equal(t, schemas.ErrProviderNetworkError, chunk.BifrostError.Error.Message,
"stream transport error must use ErrProviderNetworkError message")
// Drain any remaining chunks.
for range streamChan {
}
}
// TestChatCompletionStream_NetOpError_ChunkIsRetryable verifies the specific
// "use of closed network connection" *net.OpError scenario from issue #2424:
// a successful HTTP connection that is then closed server-side produces a
// *net.OpError during EventStream decoding, which must arrive as a retryable
// IsBifrostError:false chunk.
func TestChatCompletionStream_NetOpError_ChunkIsRetryable(t *testing.T) {
// Server: returns 200 + correct headers, writes a truncated EventStream
// prelude (not a valid frame), then forcibly resets the TCP connection —
// producing a *net.OpError on the client's read side.
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Write a partial EventStream frame header (3 bytes, not a valid frame).
_, _ = w.Write([]byte{0x00, 0x00, 0x00})
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
hj, ok := w.(http.Hijacker)
if !ok {
return
}
conn, _, _ := hj.Hijack()
// RST instead of FIN — guarantees a *net.OpError on the client read.
if tc, ok := conn.(*net.TCPConn); ok {
_ = tc.SetLinger(0)
}
conn.Close()
}))
ts.Start()
defer ts.Close()
provider := newTestProviderWithServer(t, ts)
ctx := testBedrockCtx()
key := testBedrockKey()
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
if bifrostErr != nil {
assert.False(t, bifrostErr.IsBifrostError,
"pre-stream network error must be IsBifrostError:false")
return
}
require.NotNil(t, streamChan)
// Collect chunks until we find an error chunk (may not be the very first
// if the OS buffers the partial write, but it must appear before close).
var errChunk *schemas.BifrostStreamChunk
for chunk := range streamChan {
if chunk != nil && chunk.BifrostError != nil {
errChunk = chunk
break
}
}
// Drain remaining.
for range streamChan {
}
require.NotNil(t, errChunk, "expected an error chunk from the stream")
assert.False(t, errChunk.BifrostError.IsBifrostError,
"net.OpError during EventStream decoding must be IsBifrostError:false so the retry gate can retry it")
require.NotNil(t, errChunk.BifrostError.Error)
assert.Equal(t, schemas.ErrProviderNetworkError, errChunk.BifrostError.Error.Message,
"net.OpError during EventStream decoding must use ErrProviderNetworkError message")
}
// writeEventStreamException encodes a well-formed AWS EventStream exception
// frame with the given exception type and message into w.
// The frame format is: prelude (total_len + headers_len + CRC) + headers + payload + message_CRC.
// We use the AWS SDK's eventstream.Encoder so the binary framing is correct.
func writeEventStreamException(t *testing.T, w io.Writer, excType, msg string) {
t.Helper()
enc := eventstream.NewEncoder()
payload, err := json.Marshal(map[string]string{"message": msg})
require.NoError(t, err, "failed to marshal exception payload")
headers := eventstream.Headers{
{Name: ":message-type", Value: eventstream.StringValue("exception")},
{Name: ":exception-type", Value: eventstream.StringValue(excType)},
{Name: ":content-type", Value: eventstream.StringValue("application/json")},
}
err = enc.Encode(w, eventstream.Message{Headers: headers, Payload: payload})
require.NoError(t, err, "failed to encode EventStream exception frame")
}
// TestChatCompletionStream_RetryableException_ChunkIsRetryable verifies that
// when AWS Bedrock sends a retryable exception (serviceUnavailableException,
// throttlingException, etc.) through the EventStream, the resulting error chunk
// has IsBifrostError:false and the correct HTTP StatusCode so that the retry
// gate in executeRequestWithRetries can retry the request.
func TestChatCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
tests := []struct {
excType string
expectedStatus int
}{
{"serviceUnavailableException", 503},
{"throttlingException", 429},
{"modelNotReadyException", 503},
{"internalServerException", 500},
}
for _, tc := range tests {
tc := tc
t.Run(tc.excType, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}))
defer ts.Close()
provider := newTestProviderWithServer(t, ts)
ctx := testBedrockCtx()
key := testBedrockKey()
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk")
require.NotNil(t, streamChan)
var errChunk *schemas.BifrostStreamChunk
for chunk := range streamChan {
if chunk != nil && chunk.BifrostError != nil {
errChunk = chunk
break
}
}
for range streamChan {
}
require.NotNil(t, errChunk, "expected error chunk for %s", tc.excType)
assert.False(t, errChunk.BifrostError.IsBifrostError,
"%s must be IsBifrostError:false so retry gate can retry it", tc.excType)
require.NotNil(t, errChunk.BifrostError.StatusCode,
"%s must carry a StatusCode for the retryableStatusCodes gate", tc.excType)
assert.Equal(t, tc.expectedStatus, *errChunk.BifrostError.StatusCode,
"%s must map to HTTP %d", tc.excType, tc.expectedStatus)
})
}
}
// TestChatCompletionStream_NonRetryableException_IsTerminal verifies that
// non-retryable exception types (e.g. validationException, accessDeniedException)
// continue to use ProcessAndSendError (IsBifrostError:true) and are NOT retried.
func TestChatCompletionStream_NonRetryableException_IsTerminal(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
writeEventStreamException(t, w, "validationException", "input validation failed")
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}))
defer ts.Close()
provider := newTestProviderWithServer(t, ts)
ctx := testBedrockCtx()
key := testBedrockKey()
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk")
require.NotNil(t, streamChan)
var errChunk *schemas.BifrostStreamChunk
for chunk := range streamChan {
if chunk != nil && chunk.BifrostError != nil {
errChunk = chunk
break
}
}
for range streamChan {
}
require.NotNil(t, errChunk, "expected error chunk for validationException")
assert.True(t, errChunk.BifrostError.IsBifrostError,
"non-retryable validationException must remain IsBifrostError:true")
}
// testTextCompletionRequest returns a minimal BifrostTextCompletionRequest for streaming tests.
func testTextCompletionRequest() *schemas.BifrostTextCompletionRequest {
prompt := "hello"
return &schemas.BifrostTextCompletionRequest{
Model: "anthropic.claude-sonnet-4-5",
Input: &schemas.TextCompletionInput{PromptStr: &prompt},
}
}
// testResponsesRequest returns a minimal BifrostResponsesRequest for streaming tests.
func testResponsesRequest() *schemas.BifrostResponsesRequest {
msgType := schemas.ResponsesMessageType("message")
roleUser := schemas.ResponsesMessageRoleType("user")
content := "hello"
return &schemas.BifrostResponsesRequest{
Model: "anthropic.claude-sonnet-4-5",
Input: []schemas.ResponsesMessage{
{
Type: &msgType,
Role: &roleUser,
Content: &schemas.ResponsesMessageContent{ContentStr: &content},
},
},
}
}
// assertRetryableExceptionChunk is the shared assertion helper for all three
// streaming-method retryable-exception tests.
func assertRetryableExceptionChunk(t *testing.T, streamChan chan *schemas.BifrostStreamChunk, bifrostErr *schemas.BifrostError, excType string, expectedStatus int) {
t.Helper()
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk, not a pre-stream error")
require.NotNil(t, streamChan)
var errChunk *schemas.BifrostStreamChunk
for chunk := range streamChan {
if chunk != nil && chunk.BifrostError != nil {
errChunk = chunk
break
}
}
for range streamChan {
}
require.NotNil(t, errChunk, "expected error chunk for %s", excType)
assert.False(t, errChunk.BifrostError.IsBifrostError,
"%s must be IsBifrostError:false so retry gate can retry it", excType)
require.NotNil(t, errChunk.BifrostError.StatusCode,
"%s must carry a StatusCode for the retryableStatusCodes gate", excType)
assert.Equal(t, expectedStatus, *errChunk.BifrostError.StatusCode,
"%s must map to HTTP %d", excType, expectedStatus)
}
// TestTextCompletionStream_RetryableException_ChunkIsRetryable mirrors the
// ChatCompletionStream test for the TextCompletionStream path, which has
// slightly different payload-parsing logic (extra BedrockError JSON unmarshal).
func TestTextCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
tests := []struct {
excType string
expectedStatus int
}{
{"serviceUnavailableException", 503},
{"throttlingException", 429},
{"modelNotReadyException", 503},
{"internalServerException", 500},
}
for _, tc := range tests {
tc := tc
t.Run(tc.excType, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}))
defer ts.Close()
provider := newTestProviderWithServer(t, ts)
streamChan, bifrostErr := provider.TextCompletionStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testTextCompletionRequest())
assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus)
})
}
}
// TestResponsesStream_RetryableException_ChunkIsRetryable mirrors the
// ChatCompletionStream test for the ResponsesStream path.
func TestResponsesStream_RetryableException_ChunkIsRetryable(t *testing.T) {
tests := []struct {
excType string
expectedStatus int
}{
{"serviceUnavailableException", 503},
{"throttlingException", 429},
{"modelNotReadyException", 503},
{"internalServerException", 500},
}
for _, tc := range tests {
tc := tc
t.Run(tc.excType, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}))
defer ts.Close()
provider := newTestProviderWithServer(t, ts)
streamChan, bifrostErr := provider.ResponsesStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testResponsesRequest())
assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus)
})
}
}
func generateTestCACert(t *testing.T) string {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "testca"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
IsCA: true,
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
return string(certPEM)
}
func TestBedrockTransportHTTP2Config(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
MaxConnsPerHost: 5000,
EnforceHTTP2: true,
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
require.NotNil(t, provider)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok, "transport should be *http.Transport")
assert.Equal(t, 5000, transport.MaxConnsPerHost)
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
assert.True(t, transport.ForceAttemptHTTP2)
}
func TestBedrockTransportCustomMaxConns(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
MaxConnsPerHost: 50,
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok)
assert.Equal(t, 50, transport.MaxConnsPerHost)
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
}
func TestBedrockTransportDefaultMaxConns(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
// MaxConnsPerHost left as 0 — should default to 5000
},
}
config.CheckAndSetDefaults()
assert.Equal(t, schemas.DefaultMaxConnsPerHost, config.NetworkConfig.MaxConnsPerHost)
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok)
assert.Equal(t, schemas.DefaultMaxConnsPerHost, transport.MaxConnsPerHost)
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
}
func TestBedrockTransportTLSInsecureSkipVerify(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
InsecureSkipVerify: true,
EnforceHTTP2: true,
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, transport.TLSClientConfig)
assert.True(t, transport.TLSClientConfig.InsecureSkipVerify)
assert.Equal(t, uint16(tls.VersionTLS12), transport.TLSClientConfig.MinVersion)
// ForceAttemptHTTP2 should still be true even with custom TLS config
assert.True(t, transport.ForceAttemptHTTP2)
}
func TestBedrockTransportTLSCACert(t *testing.T) {
testCACert := generateTestCACert(t)
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
CACertPEM: testCACert,
EnforceHTTP2: true,
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, transport.TLSClientConfig)
assert.NotNil(t, transport.TLSClientConfig.RootCAs)
assert.Equal(t, uint16(tls.VersionTLS12), transport.TLSClientConfig.MinVersion)
assert.True(t, transport.ForceAttemptHTTP2)
}
func TestBedrockTransportDefaultTLS(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
// No TLS settings — should use system defaults
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok)
// No custom TLS config should be set
assert.Nil(t, transport.TLSClientConfig)
// EnforceHTTP2 not set — ForceAttemptHTTP2 should be false
assert.False(t, transport.ForceAttemptHTTP2)
}
func TestBedrockTransportEnforceHTTP2(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
EnforceHTTP2: true,
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok)
assert.True(t, transport.ForceAttemptHTTP2)
// TLSNextProto should NOT be set when HTTP/2 is enforced, allowing ALPN negotiation
assert.Nil(t, transport.TLSNextProto)
}
func TestBedrockTransportEnforceHTTP2Disabled(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
EnforceHTTP2: false,
},
}
config.CheckAndSetDefaults()
provider, err := NewBedrockProvider(config, nil)
require.NoError(t, err)
transport, ok := provider.client.Transport.(*http.Transport)
require.True(t, ok)
assert.False(t, transport.ForceAttemptHTTP2)
// TLSNextProto must be set to empty map to truly disable HTTP/2 ALPN negotiation
assert.NotNil(t, transport.TLSNextProto)
assert.Empty(t, transport.TLSNextProto)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff