first commit
This commit is contained in:
417
core/providers/bedrock/batch.go
Normal file
417
core/providers/bedrock/batch.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// BedrockBatchJobRequest represents a request to create a batch inference job.
|
||||
type BedrockBatchJobRequest struct {
|
||||
JobName string `json:"jobName"`
|
||||
ModelID *string `json:"modelId"`
|
||||
RoleArn string `json:"roleArn"`
|
||||
InputDataConfig BedrockInputDataConfig `json:"inputDataConfig"`
|
||||
OutputDataConfig BedrockOutputDataConfig `json:"outputDataConfig"`
|
||||
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
|
||||
Tags []BedrockTag `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockInputDataConfig represents the input configuration for a batch job.
|
||||
type BedrockInputDataConfig struct {
|
||||
S3InputDataConfig BedrockS3InputDataConfig `json:"s3InputDataConfig"`
|
||||
}
|
||||
|
||||
// BedrockS3InputDataConfig describes the S3 location a batch job reads from.
type BedrockS3InputDataConfig struct {
	// S3Uri is always serialized (key "s3Uri").
	S3Uri string `json:"s3Uri"`
	// S3InputFormat is the input file format, e.g. "JSONL"; omitted when empty.
	S3InputFormat string `json:"s3InputFormat,omitempty"`
	// Endpoint is optional and omitted when nil.
	Endpoint *string `json:"endpoint,omitempty"`
	// FileID is optional and omitted when nil.
	// NOTE(review): the "file_id" key is snake_case unlike the other keys;
	// presumably a gateway-side extension rather than a Bedrock field - confirm.
	FileID *string `json:"file_id,omitempty"`
}
|
||||
|
||||
// BedrockOutputDataConfig represents the output configuration for a batch job.
|
||||
type BedrockOutputDataConfig struct {
|
||||
S3OutputDataConfig BedrockS3OutputDataConfig `json:"s3OutputDataConfig"`
|
||||
}
|
||||
|
||||
// BedrockS3OutputDataConfig describes the S3 location a batch job writes to.
type BedrockS3OutputDataConfig struct {
	// S3Uri is always serialized (key "s3Uri").
	S3Uri string `json:"s3Uri"`
}
|
||||
|
||||
// BedrockTag is a key/value tag attached to a batch job.
type BedrockTag struct {
	// Key is serialized as "key".
	Key string `json:"key"`
	// Value is serialized as "value".
	Value string `json:"value"`
}
|
||||
|
||||
// BedrockBatchJobResponse represents a batch job response.
|
||||
type BedrockBatchJobResponse struct {
|
||||
JobArn string `json:"jobArn"`
|
||||
Status string `json:"status"`
|
||||
JobName string `json:"jobName,omitempty"`
|
||||
ModelID string `json:"modelId,omitempty"`
|
||||
RoleArn string `json:"roleArn,omitempty"`
|
||||
InputDataConfig *BedrockInputDataConfig `json:"inputDataConfig,omitempty"`
|
||||
OutputDataConfig *BedrockOutputDataConfig `json:"outputDataConfig,omitempty"`
|
||||
VpcConfig *BedrockVpcConfig `json:"vpcConfig,omitempty"`
|
||||
SubmitTime *time.Time `json:"submitTime,omitempty"`
|
||||
LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
ClientRequestToken string `json:"clientRequestToken,omitempty"`
|
||||
JobExpirationTime *time.Time `json:"jobExpirationTime,omitempty"`
|
||||
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockBatchJobListResponse represents a list of batch jobs.
|
||||
type BedrockBatchJobListResponse struct {
|
||||
InvocationJobSummaries []BedrockBatchJobSummary `json:"invocationJobSummaries"`
|
||||
NextToken *string `json:"nextToken,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockBatchJobSummary is the condensed per-job entry used in list
// responses.
type BedrockBatchJobSummary struct {
	// The four identifying fields are always serialized, even when empty.
	JobArn  string `json:"jobArn"`
	JobName string `json:"jobName"`
	ModelID string `json:"modelId"`
	Status  string `json:"status"`

	// Lifecycle timestamps; nil pointers are omitted.
	SubmitTime       *time.Time `json:"submitTime,omitempty"`
	LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
	EndTime          *time.Time `json:"endTime,omitempty"`

	// Message is omitted when empty.
	Message string `json:"message,omitempty"`
}
|
||||
|
||||
// BedrockBatchResultRecord represents a single result record in Bedrock batch output JSONL.
|
||||
type BedrockBatchResultRecord struct {
|
||||
RecordID string `json:"recordId"`
|
||||
ModelOutput json.RawMessage `json:"modelOutput,omitempty"`
|
||||
Error *BedrockBatchError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockBatchError is the error object attached to a batch result record.
type BedrockBatchError struct {
	// ErrorCode is a numeric code (key "errorCode"); omitted when zero.
	ErrorCode int `json:"errorCode,omitempty"`
	// ErrorMessage is the accompanying text; omitted when empty.
	ErrorMessage string `json:"errorMessage,omitempty"`
}
|
||||
|
||||
// BedrockBatchListRequest carries the paging and filter options for listing
// batch jobs. Every field is optional and omitted from the JSON when unset.
type BedrockBatchListRequest struct {
	// MaxResults caps the page size.
	MaxResults int `json:"maxResults,omitempty"`
	// NextToken is the pagination token from a previous page.
	NextToken *string `json:"nextToken,omitempty"`
	// StatusEquals filters by job status.
	StatusEquals string `json:"statusEquals,omitempty"`
	// NameContains filters by job name substring.
	NameContains string `json:"nameContains,omitempty"`
}
|
||||
|
||||
// BedrockBatchRetrieveRequest identifies the batch job to fetch.
type BedrockBatchRetrieveRequest struct {
	// JobIdentifier is serialized as "jobIdentifier".
	JobIdentifier string `json:"jobIdentifier"`
}
|
||||
|
||||
// BedrockBatchCancelRequest identifies the batch job to cancel (stop).
type BedrockBatchCancelRequest struct {
	// JobIdentifier is serialized as "jobIdentifier".
	JobIdentifier string `json:"jobIdentifier"`
}
|
||||
|
||||
// BedrockBatchCancelResponse is returned after asking a batch job to stop.
type BedrockBatchCancelResponse struct {
	// JobArn is serialized as "jobArn".
	JobArn string `json:"jobArn"`
	// Status is serialized as "status".
	Status string `json:"status"`
}
|
||||
|
||||
// ToBifrostBatchStatus converts Bedrock status to Bifrost status.
|
||||
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
|
||||
switch status {
|
||||
case "Submitted", "Validating":
|
||||
return schemas.BatchStatusValidating
|
||||
case "InProgress":
|
||||
return schemas.BatchStatusInProgress
|
||||
case "Completed":
|
||||
return schemas.BatchStatusCompleted
|
||||
case "Failed", "PartiallyCompleted":
|
||||
return schemas.BatchStatusFailed
|
||||
case "Stopping":
|
||||
return schemas.BatchStatusCancelling
|
||||
case "Stopped":
|
||||
return schemas.BatchStatusCancelled
|
||||
case "Expired":
|
||||
return schemas.BatchStatusExpired
|
||||
case "Scheduled":
|
||||
return schemas.BatchStatusValidating
|
||||
default:
|
||||
return schemas.BatchStatus(status)
|
||||
}
|
||||
}
|
||||
|
||||
// parseBatchResultsJSONL parses JSONL content from Bedrock batch output into Bifrost format.
|
||||
// Returns the parsed results and any parse errors encountered.
|
||||
func parseBatchResultsJSONL(content []byte, provider *BedrockProvider) ([]schemas.BatchResultItem, []schemas.BatchError) {
|
||||
var results []schemas.BatchResultItem
|
||||
|
||||
parseResult := providerUtils.ParseJSONL(content, func(line []byte) error {
|
||||
var bedrockResult BedrockBatchResultRecord
|
||||
if err := sonic.Unmarshal(line, &bedrockResult); err != nil {
|
||||
provider.logger.Warn(fmt.Sprintf("failed to parse batch result line: %v", err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert Bedrock format to Bifrost format
|
||||
resultItem := schemas.BatchResultItem{
|
||||
CustomID: bedrockResult.RecordID,
|
||||
}
|
||||
|
||||
if bedrockResult.ModelOutput != nil {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(bedrockResult.ModelOutput, &bodyMap); err == nil {
|
||||
resultItem.Response = &schemas.BatchResultResponse{
|
||||
StatusCode: 200,
|
||||
Body: bodyMap,
|
||||
}
|
||||
} else {
|
||||
resultItem.Error = &schemas.BatchResultError{
|
||||
Code: "parse_error",
|
||||
Message: fmt.Sprintf("failed to parse model output: %v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bedrockResult.Error != nil {
|
||||
resultItem.Error = &schemas.BatchResultError{
|
||||
Code: fmt.Sprintf("%d", bedrockResult.Error.ErrorCode),
|
||||
Message: bedrockResult.Error.ErrorMessage,
|
||||
}
|
||||
// Set status code to indicate error if there's an error
|
||||
if resultItem.Response == nil {
|
||||
resultItem.Response = &schemas.BatchResultResponse{
|
||||
StatusCode: bedrockResult.Error.ErrorCode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, resultItem)
|
||||
return nil
|
||||
})
|
||||
|
||||
return results, parseResult.Errors
|
||||
}
|
||||
|
||||
// ToBedrockBatchJobResponse converts a Bifrost batch create response to Bedrock format.
|
||||
func ToBedrockBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *BedrockBatchJobResponse {
|
||||
// Here if the provider is not Bedrock - then we create a dummy arn and string using the batch ID
|
||||
if resp.ExtraFields.Provider != schemas.Bedrock {
|
||||
return &BedrockBatchJobResponse{
|
||||
JobArn: fmt.Sprintf("arn:aws:bedrock:us-east-1:444444444444:batch:%s", resp.ID),
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
}
|
||||
// For bedrock, we go as is
|
||||
result := &BedrockBatchJobResponse{
|
||||
JobArn: resp.ID,
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
|
||||
if resp.Metadata != nil {
|
||||
if jobName, ok := resp.Metadata["job_name"]; ok {
|
||||
result.JobName = jobName
|
||||
}
|
||||
}
|
||||
|
||||
if resp.CreatedAt > 0 {
|
||||
t := time.Unix(resp.CreatedAt, 0)
|
||||
result.SubmitTime = &t
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToBedrockBatchJobListResponse converts a Bifrost batch list response to Bedrock format.
|
||||
func ToBedrockBatchJobListResponse(resp *schemas.BifrostBatchListResponse) *BedrockBatchJobListResponse {
|
||||
result := &BedrockBatchJobListResponse{
|
||||
InvocationJobSummaries: make([]BedrockBatchJobSummary, len(resp.Data)),
|
||||
}
|
||||
|
||||
for i, batch := range resp.Data {
|
||||
summary := BedrockBatchJobSummary{
|
||||
JobArn: batch.ID,
|
||||
Status: toBedrockBatchStatus(batch.Status),
|
||||
}
|
||||
|
||||
if batch.Metadata != nil {
|
||||
if jobName, ok := batch.Metadata["job_name"]; ok {
|
||||
summary.JobName = jobName
|
||||
}
|
||||
if modelId, ok := batch.Metadata["model_id"]; ok {
|
||||
summary.ModelID = modelId
|
||||
}
|
||||
}
|
||||
|
||||
if batch.CreatedAt > 0 {
|
||||
t := time.Unix(batch.CreatedAt, 0)
|
||||
summary.SubmitTime = &t
|
||||
}
|
||||
|
||||
if batch.CompletedAt != nil && *batch.CompletedAt > 0 {
|
||||
t := time.Unix(*batch.CompletedAt, 0)
|
||||
summary.EndTime = &t
|
||||
}
|
||||
|
||||
result.InvocationJobSummaries[i] = summary
|
||||
}
|
||||
|
||||
if resp.LastID != nil {
|
||||
result.NextToken = resp.LastID
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToBedrockBatchJobRetrieveResponse converts a Bifrost batch retrieve response to Bedrock format.
|
||||
func ToBedrockBatchJobRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *BedrockBatchJobResponse {
|
||||
result := &BedrockBatchJobResponse{
|
||||
JobArn: resp.ID,
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
|
||||
if resp.Metadata != nil {
|
||||
if jobName, ok := resp.Metadata["job_name"]; ok {
|
||||
result.JobName = jobName
|
||||
}
|
||||
}
|
||||
|
||||
if resp.CreatedAt > 0 {
|
||||
t := time.Unix(resp.CreatedAt, 0)
|
||||
result.SubmitTime = &t
|
||||
}
|
||||
|
||||
if resp.CompletedAt != nil && *resp.CompletedAt > 0 {
|
||||
t := time.Unix(*resp.CompletedAt, 0)
|
||||
result.EndTime = &t
|
||||
}
|
||||
|
||||
if resp.InputFileID != "" {
|
||||
result.InputDataConfig = &BedrockInputDataConfig{
|
||||
S3InputDataConfig: BedrockS3InputDataConfig{
|
||||
S3Uri: resp.InputFileID,
|
||||
S3InputFormat: "JSONL",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if resp.OutputFileID != nil && *resp.OutputFileID != "" {
|
||||
result.OutputDataConfig = &BedrockOutputDataConfig{
|
||||
S3OutputDataConfig: BedrockS3OutputDataConfig{
|
||||
S3Uri: *resp.OutputFileID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// toBedrockBatchStatus converts Bifrost batch status to Bedrock status.
|
||||
func toBedrockBatchStatus(status schemas.BatchStatus) string {
|
||||
switch status {
|
||||
case schemas.BatchStatusValidating:
|
||||
return "Validating"
|
||||
case schemas.BatchStatusInProgress:
|
||||
return "InProgress"
|
||||
case schemas.BatchStatusCompleted:
|
||||
fallthrough
|
||||
case schemas.BatchStatusEnded:
|
||||
return "Completed"
|
||||
case schemas.BatchStatusFailed:
|
||||
return "Failed"
|
||||
case schemas.BatchStatusCancelling:
|
||||
return "Stopping"
|
||||
case schemas.BatchStatusCancelled:
|
||||
return "Stopped"
|
||||
case schemas.BatchStatusExpired:
|
||||
return "Expired"
|
||||
default:
|
||||
return string(status)
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostBatchListRequest converts a Bedrock batch list request to Bifrost format.
|
||||
func ToBifrostBatchListRequest(req *BedrockBatchListRequest, provider schemas.ModelProvider) *schemas.BifrostBatchListRequest {
|
||||
result := &schemas.BifrostBatchListRequest{
|
||||
Provider: provider,
|
||||
Limit: req.MaxResults,
|
||||
}
|
||||
|
||||
if req.NextToken != nil {
|
||||
result.PageToken = req.NextToken
|
||||
}
|
||||
|
||||
if req.StatusEquals != "" || req.NameContains != "" {
|
||||
result.ExtraParams = make(map[string]interface{})
|
||||
if req.StatusEquals != "" {
|
||||
result.ExtraParams["statusEquals"] = req.StatusEquals
|
||||
}
|
||||
if req.NameContains != "" {
|
||||
result.ExtraParams["nameContains"] = req.NameContains
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToBifrostBatchRetrieveRequest converts a Bedrock batch retrieve request to Bifrost format.
|
||||
func ToBifrostBatchRetrieveRequest(req *BedrockBatchRetrieveRequest, provider schemas.ModelProvider) *schemas.BifrostBatchRetrieveRequest {
|
||||
return &schemas.BifrostBatchRetrieveRequest{
|
||||
Provider: provider,
|
||||
BatchID: req.JobIdentifier,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostBatchCancelRequest converts a Bedrock batch cancel request to Bifrost format.
|
||||
func ToBifrostBatchCancelRequest(req *BedrockBatchCancelRequest, provider schemas.ModelProvider) *schemas.BifrostBatchCancelRequest {
|
||||
return &schemas.BifrostBatchCancelRequest{
|
||||
Provider: provider,
|
||||
BatchID: req.JobIdentifier,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockBatchCancelResponse converts a Bifrost batch cancel response to Bedrock format.
|
||||
func ToBedrockBatchCancelResponse(resp *schemas.BifrostBatchCancelResponse) *BedrockBatchCancelResponse {
|
||||
return &BedrockBatchCancelResponse{
|
||||
JobArn: resp.ID,
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
}
|
||||
|
||||
// splitJSONL splits JSONL content into its non-empty lines.
//
// Lines are separated by '\n'. A trailing '\r' (from CRLF line endings) is
// stripped from each line, and lines that are empty after stripping are
// skipped. The final line does not need a terminating newline. The returned
// slices alias data; nil is returned when there are no non-empty lines.
func splitJSONL(data []byte) [][]byte {
	var lines [][]byte

	// keep records one candidate line, normalising CRLF and dropping blanks.
	keep := func(line []byte) {
		if n := len(line); n > 0 && line[n-1] == '\r' {
			line = line[:n-1]
		}
		if len(line) > 0 {
			lines = append(lines, line)
		}
	}

	start := 0
	for i, b := range data {
		if b == '\n' {
			keep(data[start:i])
			start = i + 1
		}
	}
	// Whatever follows the last newline is the final, unterminated line.
	keep(data[start:])

	return lines
}
|
||||
|
||||
// BedrockVpcConfig holds the VPC settings of a batch job. Both lists are
// omitted from the JSON when empty.
type BedrockVpcConfig struct {
	// SecurityGroupIds is serialized as "securityGroupIds".
	SecurityGroupIds []string `json:"securityGroupIds,omitempty"`
	// SubnetIds is serialized as "subnetIds".
	SubnetIds []string `json:"subnetIds,omitempty"`
}
|
||||
|
||||
// BedrockBatchManifest mirrors the manifest.json.out file that accompanies
// batch output in S3. All three counters are always serialized.
type BedrockBatchManifest struct {
	// TotalRecordCount is serialized as "totalRecordCount".
	TotalRecordCount int `json:"totalRecordCount"`
	// ProcessedRecordCount is serialized as "processedRecordCount".
	ProcessedRecordCount int `json:"processedRecordCount"`
	// ErrorRecordCount is serialized as "errorRecordCount".
	ErrorRecordCount int `json:"errorRecordCount"`
}
|
||||
3597
core/providers/bedrock/bedrock.go
Normal file
3597
core/providers/bedrock/bedrock.go
Normal file
File diff suppressed because it is too large
Load Diff
4326
core/providers/bedrock/bedrock_test.go
Normal file
4326
core/providers/bedrock/bedrock_test.go
Normal file
File diff suppressed because it is too large
Load Diff
449
core/providers/bedrock/chat.go
Normal file
449
core/providers/bedrock/chat.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockChatCompletionRequest converts a Bifrost request to Bedrock Converse API format
|
||||
func ToBedrockChatCompletionRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) (*BedrockConverseRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil")
|
||||
}
|
||||
|
||||
if bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("only chat completion requests are supported for Bedrock Converse API")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockConverseRequest{
|
||||
ModelID: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert messages and system messages
|
||||
messages, systemMessages, err := convertMessages(bifrostReq.Input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
bedrockReq.Messages = messages
|
||||
if len(systemMessages) > 0 {
|
||||
bedrockReq.System = systemMessages
|
||||
}
|
||||
|
||||
// Convert parameters and configurations
|
||||
if err := convertChatParameters(ctx, bifrostReq, bedrockReq); err != nil {
|
||||
return nil, fmt.Errorf("failed to convert chat parameters: %w", err)
|
||||
}
|
||||
|
||||
// Ensure tool config is present when needed
|
||||
ensureChatToolConfigForConversation(bifrostReq, bedrockReq)
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostChatResponse converts a Bedrock Converse API response to Bifrost format
|
||||
func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Context, model string) (*schemas.BifrostChatResponse, error) {
|
||||
if response == nil {
|
||||
return nil, fmt.Errorf("bedrock response is nil")
|
||||
}
|
||||
|
||||
// Convert content blocks and tool calls
|
||||
var contentStr *string
|
||||
var contentBlocks []schemas.ChatContentBlock
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
var reasoningText string
|
||||
|
||||
if response.Output.Message != nil {
|
||||
for _, contentBlock := range response.Output.Message.Content {
|
||||
// Handle text content
|
||||
if contentBlock.Text != nil && *contentBlock.Text != "" {
|
||||
chatContentBlock := schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: contentBlock.Text,
|
||||
}
|
||||
contentBlocks = append(contentBlocks, chatContentBlock)
|
||||
}
|
||||
|
||||
if contentBlock.ToolUse != nil {
|
||||
// Check if this is the structured output tool
|
||||
if structuredOutputToolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok && contentBlock.ToolUse.Name == structuredOutputToolName {
|
||||
// This is structured output - set contentStr and skip adding to toolCalls
|
||||
if contentBlock.ToolUse.Input != nil {
|
||||
jsonStr := string(contentBlock.ToolUse.Input)
|
||||
contentStr = &jsonStr
|
||||
}
|
||||
continue // Skip adding to toolCalls
|
||||
}
|
||||
|
||||
// Regular tool call processing
|
||||
var arguments string
|
||||
if contentBlock.ToolUse.Input != nil {
|
||||
arguments = string(contentBlock.ToolUse.Input)
|
||||
} else {
|
||||
arguments = "{}"
|
||||
}
|
||||
|
||||
toolUseID := contentBlock.ToolUse.ToolUseID
|
||||
toolUseName := contentBlock.ToolUse.Name
|
||||
|
||||
toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: &toolUseID,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &toolUseName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle reasoning content
|
||||
if contentBlock.ReasoningContent != nil {
|
||||
if contentBlock.ReasoningContent.ReasoningText == nil {
|
||||
continue
|
||||
}
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: contentBlock.ReasoningContent.ReasoningText.Text,
|
||||
Signature: contentBlock.ReasoningContent.ReasoningText.Signature,
|
||||
})
|
||||
if contentBlock.ReasoningContent.ReasoningText.Text != nil {
|
||||
reasoningText += *contentBlock.ReasoningContent.ReasoningText.Text + "\n"
|
||||
}
|
||||
}
|
||||
|
||||
// Handle document content
|
||||
if contentBlock.Document != nil {
|
||||
fileBlock := schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeFile,
|
||||
File: &schemas.ChatInputFile{},
|
||||
}
|
||||
|
||||
// Set filename from document name
|
||||
if contentBlock.Document.Name != "" {
|
||||
fileBlock.File.Filename = &contentBlock.Document.Name
|
||||
}
|
||||
|
||||
// Set file type based on format
|
||||
if contentBlock.Document.Format != "" {
|
||||
var fileType string
|
||||
switch contentBlock.Document.Format {
|
||||
case "pdf":
|
||||
fileType = "application/pdf"
|
||||
case "txt":
|
||||
fileType = "text/plain"
|
||||
case "md":
|
||||
fileType = "text/markdown"
|
||||
case "html":
|
||||
fileType = "text/html"
|
||||
case "csv":
|
||||
fileType = "text/csv"
|
||||
case "doc":
|
||||
fileType = "application/msword"
|
||||
case "docx":
|
||||
fileType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
case "xls":
|
||||
fileType = "application/vnd.ms-excel"
|
||||
case "xlsx":
|
||||
fileType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
default:
|
||||
fileType = "application/pdf"
|
||||
}
|
||||
fileBlock.File.FileType = &fileType
|
||||
}
|
||||
|
||||
// Convert document source data
|
||||
if contentBlock.Document.Source != nil {
|
||||
if contentBlock.Document.Source.Bytes != nil {
|
||||
fileBlock.File.FileData = contentBlock.Document.Source.Bytes
|
||||
} else if contentBlock.Document.Source.Text != nil {
|
||||
fileBlock.File.FileData = contentBlock.Document.Source.Text
|
||||
}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, fileBlock)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
|
||||
contentStr = contentBlocks[0].Text
|
||||
contentBlocks = nil
|
||||
}
|
||||
|
||||
// Create the message content
|
||||
messageContent := schemas.ChatMessageContent{
|
||||
ContentStr: contentStr,
|
||||
ContentBlocks: contentBlocks,
|
||||
}
|
||||
|
||||
// Create assistant message if we have tool calls
|
||||
var assistantMessage *schemas.ChatAssistantMessage
|
||||
if len(toolCalls) > 0 {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
if len(reasoningDetails) > 0 {
|
||||
if assistantMessage == nil {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
assistantMessage.ReasoningDetails = reasoningDetails
|
||||
if reasoningText != "" {
|
||||
assistantMessage.Reasoning = new(reasoningText)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the response choice
|
||||
choices := []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &messageContent,
|
||||
ChatAssistantMessage: assistantMessage,
|
||||
},
|
||||
},
|
||||
FinishReason: schemas.Ptr(convertBedrockStopReason(response.StopReason)),
|
||||
},
|
||||
}
|
||||
var usage *schemas.BifrostLLMUsage
|
||||
if response.Usage != nil {
|
||||
// Convert usage information
|
||||
usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
}
|
||||
// Handle cached tokens if present
|
||||
if response.Usage.CacheReadInputTokens > 0 {
|
||||
if usage.PromptTokensDetails == nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
usage.PromptTokensDetails.CachedReadTokens = response.Usage.CacheReadInputTokens
|
||||
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheReadInputTokens
|
||||
}
|
||||
if response.Usage.CacheWriteInputTokens > 0 {
|
||||
if usage.PromptTokensDetails == nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
usage.PromptTokensDetails.CachedWriteTokens = response.Usage.CacheWriteInputTokens
|
||||
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheWriteInputTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Create the final Bifrost response
|
||||
bifrostResponse := &schemas.BifrostChatResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: model,
|
||||
Object: "chat.completion",
|
||||
Choices: choices,
|
||||
Usage: usage,
|
||||
Created: int(time.Now().Unix()),
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
|
||||
if response.ServiceTier != nil && response.ServiceTier.Type != "" {
|
||||
bifrostResponse.ServiceTier = &response.ServiceTier.Type
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
|
||||
// BedrockStreamState tracks per-stream tool call index state.
type BedrockStreamState struct {
	// nextToolCallIndex is the index handed to the next tool call that starts.
	nextToolCallIndex int
	// contentBlockToToolCallIdx maps a content block index to the tool call
	// index assigned to it.
	contentBlockToToolCallIdx map[int]int
}
|
||||
|
||||
// NewBedrockStreamState returns initialised stream state for one streaming response.
|
||||
func NewBedrockStreamState() *BedrockStreamState {
|
||||
return &BedrockStreamState{
|
||||
contentBlockToToolCallIdx: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (chunk *BedrockStreamEvent) ToBifrostChatCompletionStream(state *BedrockStreamState) (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
|
||||
if state == nil {
|
||||
state = NewBedrockStreamState()
|
||||
} else if state.contentBlockToToolCallIdx == nil {
|
||||
state.contentBlockToToolCallIdx = make(map[int]int)
|
||||
}
|
||||
|
||||
// event with metrics/usage is the last and with stop reason is the second last
|
||||
switch {
|
||||
case chunk.Role != nil:
|
||||
// Send empty response to signal start
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Role: chunk.Role,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
|
||||
case chunk.Start != nil && chunk.Start.ToolUse != nil:
|
||||
toolUseStart := chunk.Start.ToolUse
|
||||
|
||||
toolCallIdx := 0
|
||||
if chunk.ContentBlockIndex != nil {
|
||||
toolCallIdx = state.nextToolCallIndex
|
||||
state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex] = toolCallIdx
|
||||
state.nextToolCallIndex++
|
||||
}
|
||||
|
||||
// Create tool call structure for start event
|
||||
var toolCall schemas.ChatAssistantMessageToolCall
|
||||
toolCall.Index = uint16(toolCallIdx)
|
||||
toolCall.ID = schemas.Ptr(toolUseStart.ToolUseID)
|
||||
toolCall.Type = schemas.Ptr("function")
|
||||
toolCall.Function.Name = schemas.Ptr(toolUseStart.Name)
|
||||
toolCall.Function.Arguments = "" // Start with empty arguments
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
|
||||
case chunk.Delta != nil:
|
||||
switch {
|
||||
case chunk.Delta.Text != nil:
|
||||
// Handle text delta
|
||||
text := *chunk.Delta.Text
|
||||
if text != "" {
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Content: &text,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case chunk.Delta.ToolUse != nil:
|
||||
// Handle tool use delta
|
||||
toolUseDelta := chunk.Delta.ToolUse
|
||||
|
||||
toolCallIdx := 0
|
||||
if chunk.ContentBlockIndex != nil {
|
||||
toolCallIdx = state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex]
|
||||
}
|
||||
|
||||
// Create tool call structure
|
||||
var toolCall schemas.ChatAssistantMessageToolCall
|
||||
toolCall.Index = uint16(toolCallIdx)
|
||||
toolCall.Type = schemas.Ptr("function")
|
||||
|
||||
// For streaming, we need to accumulate tool use data
|
||||
// This is a simplified approach - in practice, you'd need to track tool calls across chunks
|
||||
toolCall.Function.Arguments = toolUseDelta.Input
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
|
||||
case chunk.Delta.ReasoningContent != nil:
|
||||
// Handle reasoning content delta
|
||||
reasoningContentDelta := chunk.Delta.ReasoningContent
|
||||
|
||||
// Only construct and return a response when either Text or Signature is set
|
||||
if (reasoningContentDelta.Text == nil || *reasoningContentDelta.Text == "") && reasoningContentDelta.Signature == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
var streamResponse *schemas.BifrostChatResponse
|
||||
if reasoningContentDelta.Text != nil && *reasoningContentDelta.Text != "" {
|
||||
streamResponse = &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Reasoning: reasoningContentDelta.Text,
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: reasoningContentDelta.Text,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else if reasoningContentDelta.Signature != nil {
|
||||
streamResponse = &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Signature: reasoningContentDelta.Signature,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
477
core/providers/bedrock/convert_tool_config_test.go
Normal file
477
core/providers/bedrock/convert_tool_config_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestConvertToolConfig_DropsServerToolsOnBedrock locks in the bug fix from
|
||||
// the user-reported repro: sending `web_search_20260209` via the OpenAI-
|
||||
// compatible /v1/chat/completions endpoint to Bedrock was producing a
|
||||
// malformed ToolConfig that Bedrock rejected with 400 "The provided request
|
||||
// is not valid". The fix strips unsupported server tools before the
|
||||
// conversion loop so the outbound request is valid.
|
||||
func TestConvertToolConfig_DropsServerToolsOnBedrock(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: schemas.Ptr("Get weather by city"),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
// Server tool — Bedrock doesn't support web_search per Table 20.
|
||||
// Should be stripped silently.
|
||||
Type: schemas.ChatToolType("web_search_20260209"),
|
||||
Name: "web_search",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
|
||||
if cfg == nil {
|
||||
t.Fatalf("expected ToolConfig, got nil (function tool should have survived)")
|
||||
}
|
||||
if len(cfg.Tools) != 1 {
|
||||
t.Fatalf("expected exactly 1 tool (function), got %d: %+v", len(cfg.Tools), cfg.Tools)
|
||||
}
|
||||
if cfg.Tools[0].ToolSpec == nil || cfg.Tools[0].ToolSpec.Name != "get_weather" {
|
||||
t.Errorf("expected function tool 'get_weather' to survive, got %+v", cfg.Tools[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToolConfig_ReturnsNilWhenAllDropped locks in the empty-slice
|
||||
// guard. Bedrock's Converse API rejects `"toolConfig": {"tools": []}` with a
|
||||
// 400; when every tool is unsupported and gets stripped, convertToolConfig
|
||||
// must return nil so no ToolConfig ships at all.
|
||||
func TestConvertToolConfig_ReturnsNilWhenAllDropped(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("web_search_20260209"),
|
||||
Name: "web_search",
|
||||
},
|
||||
{
|
||||
Type: schemas.ChatToolType("web_fetch_20260309"),
|
||||
Name: "web_fetch",
|
||||
},
|
||||
{
|
||||
Type: schemas.ChatToolType("code_execution_20250825"),
|
||||
Name: "code_execution",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
|
||||
if cfg != nil {
|
||||
t.Fatalf("expected nil ToolConfig (all tools unsupported on Bedrock), got %+v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToolConfig_KeepsBedrockSupportedServerTools — locks in that
|
||||
// Bedrock-supported server tools (bash, memory, text_editor, computer,
|
||||
// tool_search) do NOT appear in Converse's typed toolConfig.tools slot —
|
||||
// they must be tunneled via additionalModelRequestFields (exercised in
|
||||
// TestCollectBedrockServerTools_*). If the only tool is a server tool,
|
||||
// toolConfig is nil so we don't ship {"toolConfig": {"tools": []}}.
|
||||
func TestConvertToolConfig_KeepsBedrockSupportedServerTools(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
|
||||
if cfg != nil {
|
||||
t.Fatalf("expected nil toolConfig (server tools flow via additionalModelRequestFields, not toolSpec), got %+v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_BashOnly — bash is Bedrock-supported per the
|
||||
// B-header list; the helper must emit it as a native-JSON tool entry with no
|
||||
// derived beta header (bash has no high-confidence 1:1 beta-header mapping;
|
||||
// callers rely on extra-headers for that).
|
||||
func TestCollectBedrockServerTools_BashOnly(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 server tool, got %d", len(tools))
|
||||
}
|
||||
got := string(tools[0])
|
||||
if !strings.Contains(got, `"type":"bash_20250124"`) || !strings.Contains(got, `"name":"bash"`) {
|
||||
t.Errorf("expected native Anthropic bash shape, got %s", got)
|
||||
}
|
||||
if len(betas) != 0 {
|
||||
t.Errorf("expected no derived beta headers for bash (no 1:1 mapping), got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_ComputerDerivesBeta — computer_YYYYMMDD must
|
||||
// derive computer-use-YYYY-MM-DD as the beta header, gated through
|
||||
// FilterBetaHeadersForProvider(Bedrock) which keeps computer-use-* headers.
|
||||
func TestCollectBedrockServerTools_ComputerDerivesBeta(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("computer_20251124"),
|
||||
Name: "computer",
|
||||
DisplayWidthPx: schemas.Ptr(1280),
|
||||
DisplayHeightPx: schemas.Ptr(800),
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 server tool, got %d", len(tools))
|
||||
}
|
||||
if !strings.Contains(string(tools[0]), `"display_width_px":1280`) {
|
||||
t.Errorf("expected computer variant fields to flow through, got %s", string(tools[0]))
|
||||
}
|
||||
if len(betas) != 1 || betas[0] != "computer-use-2025-11-24" {
|
||||
t.Errorf("expected [computer-use-2025-11-24], got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_MemoryDerivesContextManagement — memory
|
||||
// activates via the context-management-2025-06-27 bundle on Bedrock (cite:
|
||||
// anthropic/types.go:179).
|
||||
func TestCollectBedrockServerTools_MemoryDerivesContextManagement(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("memory_20250818"),
|
||||
Name: "memory",
|
||||
},
|
||||
},
|
||||
}
|
||||
_, betas := collectBedrockServerTools(params)
|
||||
if len(betas) != 1 || betas[0] != "context-management-2025-06-27" {
|
||||
t.Errorf("expected [context-management-2025-06-27], got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_StripsUnsupported — web_search isn't in
|
||||
// Bedrock's ProviderFeatures (WebSearch=false), so ValidateChatToolsForProvider
|
||||
// drops it and the helper must emit nothing.
|
||||
func TestCollectBedrockServerTools_StripsUnsupported(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("web_search_20260209"),
|
||||
Name: "web_search",
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("expected no server tools (web_search unsupported on Bedrock), got %d", len(tools))
|
||||
}
|
||||
if len(betas) != 0 {
|
||||
t.Errorf("expected no betas when all tools filtered, got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_FunctionToolsIgnored — function/custom tools
|
||||
// go through convertToolConfig, not this helper.
|
||||
func TestCollectBedrockServerTools_FunctionToolsIgnored(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 0 || len(betas) != 0 {
|
||||
t.Errorf("function tools should not flow through server-tool helper, got tools=%d betas=%v", len(tools), betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_PinnedServerTool — caller pins a kept
|
||||
// server tool (computer) by name. Converse's typed toolConfig.toolChoice path
|
||||
// can't carry this because toolConfig.tools doesn't include server tools; the
|
||||
// existing reconciliation silently drops the pin. The tunneled path must
|
||||
// emit {"type":"tool","name":"computer"} into additionalModelRequestFields.
|
||||
func TestBuildBedrockServerToolChoice_PinnedServerTool(t *testing.T) {
|
||||
computer := schemas.ChatTool{
|
||||
Type: schemas.ChatToolType("computer_20251124"),
|
||||
Name: "computer",
|
||||
DisplayWidthPx: schemas.Ptr(1280),
|
||||
}
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{computer},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
|
||||
},
|
||||
},
|
||||
}
|
||||
choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{computer})
|
||||
if !ok {
|
||||
t.Fatalf("expected tunneled tool_choice for pinned server tool, got (nil, false)")
|
||||
}
|
||||
got := string(choice)
|
||||
if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) {
|
||||
t.Errorf("expected Anthropic-native {type:tool,name:computer}, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled — function
|
||||
// tool pins stay on Converse's typed path (toolConfig.toolChoice.tool). The
|
||||
// helper must not double-emit.
|
||||
func TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled(t *testing.T) {
|
||||
fn := schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: &schemas.ToolFunctionParameters{Type: "object"},
|
||||
},
|
||||
}
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{fn},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "get_weather"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn}); ok {
|
||||
t.Errorf("expected no tunneling for function-tool pin (typed Converse path handles it)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools — tool_choice:any
|
||||
// with only server tools: convertToolConfig returns nil (bedrockTools empty),
|
||||
// so the typed any-contract is lost. The tunneled path must emit
|
||||
// {"type":"any"} to preserve the forcing semantics.
|
||||
func TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools(t *testing.T) {
|
||||
bash := schemas.ChatTool{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
}
|
||||
anyStr := string(schemas.ChatToolChoiceTypeAny)
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{bash},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: &anyStr,
|
||||
},
|
||||
}
|
||||
choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{bash})
|
||||
if !ok {
|
||||
t.Fatalf("expected tunneled any-contract when only server tools are present, got (nil, false)")
|
||||
}
|
||||
got := string(choice)
|
||||
if !strings.Contains(got, `"type":"any"`) {
|
||||
t.Errorf("expected {type:any}, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled — when at
|
||||
// least one function/custom tool is present, Converse's typed
|
||||
// toolConfig.toolChoice.any carries the any-contract. Don't double-emit.
|
||||
func TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled(t *testing.T) {
|
||||
fn := schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: &schemas.ToolFunctionParameters{Type: "object"},
|
||||
},
|
||||
}
|
||||
bash := schemas.ChatTool{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
}
|
||||
anyStr := string(schemas.ChatToolChoiceTypeAny)
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{fn, bash},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: &anyStr,
|
||||
},
|
||||
}
|
||||
if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn, bash}); ok {
|
||||
t.Errorf("expected no tunneling when function/custom tool is present (typed Converse path handles any)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled — the
|
||||
// caller pins web_search, which ValidateChatToolsForProvider strips on
|
||||
// Bedrock. The pin name is absent from the filtered set; the helper must not
|
||||
// fabricate a tunneled tool_choice for a tool that isn't in the request.
|
||||
func TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled(t *testing.T) {
|
||||
// The caller's original request had web_search, but it's been stripped.
|
||||
// We pass the filtered slice (empty for the server-tool axis) to mimic
|
||||
// the convertChatParameters call path.
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{Type: schemas.ChatToolType("web_search_20260209"), Name: "web_search"}},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "web_search"},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Filtered (post-ValidateChatToolsForProvider(Bedrock)) — web_search is dropped.
|
||||
filtered := []schemas.ChatTool{}
|
||||
if _, ok := buildBedrockServerToolChoice(params, filtered); ok {
|
||||
t.Errorf("expected no tunneling when pinned name was stripped by provider validation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertChatParameters_PinnedServerToolE2E — end-to-end verification
|
||||
// that convertChatParameters composes convertToolConfig +
|
||||
// collectBedrockServerTools + buildBedrockServerToolChoice such that a
|
||||
// request pinning a kept server tool produces:
|
||||
// - AdditionalModelRequestFields.tools containing the server tool
|
||||
// - AdditionalModelRequestFields.tool_choice with Anthropic-native shape
|
||||
// - ToolConfig nil (no function tools → Converse's typed path is inactive)
|
||||
func TestConvertChatParameters_PinnedServerToolE2E(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Model: "global.anthropic.claude-sonnet-4-6",
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("computer_20251124"),
|
||||
Name: "computer",
|
||||
DisplayWidthPx: schemas.Ptr(1280),
|
||||
},
|
||||
},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bedrockReq := &BedrockConverseRequest{}
|
||||
if err := convertChatParameters(nil, bifrostReq, bedrockReq); err != nil {
|
||||
t.Fatalf("convertChatParameters failed: %v", err)
|
||||
}
|
||||
if bedrockReq.ToolConfig != nil {
|
||||
t.Errorf("expected nil ToolConfig (no function/custom tools), got %+v", bedrockReq.ToolConfig)
|
||||
}
|
||||
if bedrockReq.AdditionalModelRequestFields == nil {
|
||||
t.Fatalf("expected AdditionalModelRequestFields to carry server-tool payload, got nil")
|
||||
}
|
||||
tools, ok := bedrockReq.AdditionalModelRequestFields.Get("tools")
|
||||
if !ok {
|
||||
t.Errorf("expected additionalModelRequestFields.tools to be set for server tool")
|
||||
} else if toolsSlice, castOK := tools.([]json.RawMessage); !castOK || len(toolsSlice) != 1 {
|
||||
t.Errorf("expected 1 server tool in additionalModelRequestFields.tools, got %+v", tools)
|
||||
}
|
||||
choice, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice")
|
||||
if !ok {
|
||||
t.Fatalf("expected additionalModelRequestFields.tool_choice to carry pinned server-tool contract")
|
||||
}
|
||||
choiceRaw, castOK := choice.(json.RawMessage)
|
||||
if !castOK {
|
||||
t.Fatalf("expected tool_choice value to be json.RawMessage, got %T", choice)
|
||||
}
|
||||
got := string(choiceRaw)
|
||||
if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) {
|
||||
t.Errorf("expected {type:tool,name:computer}, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice
|
||||
// locks in the fix for the "two conflicting tool-choice directives" hazard:
|
||||
// when response_format forces the synthetic bf_so_* tool via
|
||||
// ToolConfig.ToolChoice, the tunneled additionalModelRequestFields.tool_choice
|
||||
// (which would pin a server tool) must be suppressed so Bedrock doesn't
|
||||
// receive both pins in the same Converse call. Uses a Nova model since
|
||||
// Anthropic models route response_format through native output_config.format
|
||||
// (no synthetic tool), so the conflict only surfaces on non-Anthropic
|
||||
// Bedrock targets.
|
||||
func TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice(t *testing.T) {
|
||||
responseFormat := any(map[string]any{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]any{
|
||||
"name": "classification",
|
||||
"schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"topic": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"topic"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Model: "amazon.nova-pro-v1:0",
|
||||
Params: &schemas.ChatParameters{
|
||||
ResponseFormat: &responseFormat,
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
},
|
||||
},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "bash"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bedrockReq := &BedrockConverseRequest{}
|
||||
if err := convertChatParameters(ctx, bifrostReq, bedrockReq); err != nil {
|
||||
t.Fatalf("convertChatParameters failed: %v", err)
|
||||
}
|
||||
|
||||
// Synthetic bf_so_* tool must be injected and pinned via Converse's typed path.
|
||||
if bedrockReq.ToolConfig == nil {
|
||||
t.Fatalf("expected ToolConfig with synthetic bf_so_* tool, got nil")
|
||||
}
|
||||
if bedrockReq.ToolConfig.ToolChoice == nil || bedrockReq.ToolConfig.ToolChoice.Tool == nil {
|
||||
t.Fatalf("expected ToolConfig.ToolChoice.Tool to pin synthetic structured-output tool, got %+v", bedrockReq.ToolConfig.ToolChoice)
|
||||
}
|
||||
if !strings.HasPrefix(bedrockReq.ToolConfig.ToolChoice.Tool.Name, "bf_so_") {
|
||||
t.Errorf("expected ToolConfig.ToolChoice.Tool.Name to start with bf_so_, got %q", bedrockReq.ToolConfig.ToolChoice.Tool.Name)
|
||||
}
|
||||
|
||||
// Server tool must still be tunneled so the model has it available.
|
||||
if bedrockReq.AdditionalModelRequestFields == nil {
|
||||
t.Fatalf("expected AdditionalModelRequestFields to carry tunneled server-tool payload, got nil")
|
||||
}
|
||||
if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tools"); !ok {
|
||||
t.Errorf("expected additionalModelRequestFields.tools to still carry bash server tool")
|
||||
}
|
||||
|
||||
// Guarded field: tunneled tool_choice MUST be absent because response_format
|
||||
// forces the synthetic tool. Two tool-choice directives in the same request
|
||||
// would let Bedrock pick one and silently violate the structured-output contract.
|
||||
if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice"); ok {
|
||||
t.Errorf("expected NO additionalModelRequestFields.tool_choice when response_format pins bf_so_* (conflict hazard)")
|
||||
}
|
||||
}
|
||||
57
core/providers/bedrock/count_tokens.go
Normal file
57
core/providers/bedrock/count_tokens.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// estimatedBytesPerToken is the divisor estimateTokenCount applies to the
// request body's byte length to derive its rough token estimate.
const estimatedBytesPerToken = 4
|
||||
|
||||
// ToBifrostCountTokensResponse converts a Bedrock count tokens response to Bifrost format
|
||||
func (resp *BedrockCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
totalTokens := resp.InputTokens
|
||||
|
||||
return &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
InputTokens: resp.InputTokens,
|
||||
TotalTokens: &totalTokens,
|
||||
Object: "response.input_tokens",
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockCountTokensResponse converts a Bifrost count tokens response to Bedrock native format
|
||||
func ToBedrockCountTokensResponse(resp *schemas.BifrostCountTokensResponse) *BedrockCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BedrockCountTokensResponse{
|
||||
InputTokens: resp.InputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// isCountTokensUnsupported checks whether a BifrostError indicates that the
|
||||
// Bedrock model does not support the count-tokens operation.
|
||||
func isCountTokensUnsupported(err *schemas.BifrostError) bool {
|
||||
if err == nil || err.Error == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error.Message), "doesn't support counting tokens")
|
||||
}
|
||||
|
||||
// estimateTokenCount returns a rough token count derived from the byte length
|
||||
// of the serialized request body. Claude's tokenizer averages ~4 bytes per
|
||||
// token on mixed content; this intentionally rounds up so that context-window
|
||||
// management decisions stay on the conservative side.
|
||||
func estimateTokenCount(requestBody []byte) int {
|
||||
n := len(requestBody)
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
return (n + estimatedBytesPerToken - 1) / estimatedBytesPerToken
|
||||
}
|
||||
105
core/providers/bedrock/count_tokens_test.go
Normal file
105
core/providers/bedrock/count_tokens_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsCountTokensUnsupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *schemas.BifrostError
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil error field",
|
||||
err: &schemas.BifrostError{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "matching bedrock error message",
|
||||
err: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "The provided model doesn't support counting tokens.",
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "matching message with different casing",
|
||||
err: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "the provided model DOESN'T SUPPORT COUNTING TOKENS.",
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "unrelated error message",
|
||||
err: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "access denied",
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, isCountTokensUnsupported(tc.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokenCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "nil input",
|
||||
input: nil,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "exact multiple of 4",
|
||||
input: make([]byte, 100),
|
||||
expected: 25,
|
||||
},
|
||||
{
|
||||
name: "rounds up",
|
||||
input: make([]byte, 101),
|
||||
expected: 26,
|
||||
},
|
||||
{
|
||||
name: "single byte",
|
||||
input: []byte("x"),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "realistic json body",
|
||||
input: []byte(`{"messages":[{"role":"user","content":"Hello, how are you today?"}],"model":"us.anthropic.claude-sonnet-4-6"}`),
|
||||
expected: 28,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, estimateTokenCount(tc.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
271
core/providers/bedrock/embedding.go
Normal file
271
core/providers/bedrock/embedding.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to Bedrock Titan format
|
||||
func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockTitanEmbeddingRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost embedding request is nil")
|
||||
}
|
||||
|
||||
// Validate that only single text input is provided for Titan models
|
||||
if bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0 {
|
||||
return nil, fmt.Errorf("no input text provided for embedding")
|
||||
}
|
||||
|
||||
titanReq := &BedrockTitanEmbeddingRequest{}
|
||||
|
||||
// Set input text
|
||||
if bifrostReq.Input.Text != nil {
|
||||
titanReq.InputText = *bifrostReq.Input.Text
|
||||
} else if len(bifrostReq.Input.Texts) > 0 {
|
||||
var embeddingText string
|
||||
for _, text := range bifrostReq.Input.Texts {
|
||||
embeddingText += text + " \n"
|
||||
}
|
||||
titanReq.InputText = embeddingText
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
titanReq.Dimensions = bifrostReq.Params.Dimensions
|
||||
if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok {
|
||||
if b, ok := normalize.(bool); ok {
|
||||
titanReq.Normalize = &b
|
||||
}
|
||||
}
|
||||
// Forward remaining extra params (excluding normalize which is now a first-class field)
|
||||
if len(bifrostReq.Params.ExtraParams) > 0 {
|
||||
extra := make(map[string]interface{})
|
||||
for k, v := range bifrostReq.Params.ExtraParams {
|
||||
if k != "normalize" {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
if len(extra) > 0 {
|
||||
titanReq.ExtraParams = extra
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return titanReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Bedrock Titan embedding response to Bifrost format
|
||||
func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: []schemas.EmbeddingData{
|
||||
{
|
||||
Index: 0,
|
||||
Object: "embedding",
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: response.Embedding,
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: response.InputTextTokenCount,
|
||||
TotalTokens: response.InputTextTokenCount,
|
||||
},
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format.
|
||||
// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the request body.
|
||||
func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost embedding request is nil")
|
||||
}
|
||||
if bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0) {
|
||||
return nil, fmt.Errorf("no input provided for embedding")
|
||||
}
|
||||
|
||||
req := &BedrockCohereEmbeddingRequest{}
|
||||
|
||||
// Map texts
|
||||
if bifrostReq.Input.Text != nil {
|
||||
req.Texts = []string{*bifrostReq.Input.Text}
|
||||
} else if len(bifrostReq.Input.Texts) > 0 {
|
||||
req.Texts = bifrostReq.Input.Texts
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams))
|
||||
for k, v := range bifrostReq.Params.ExtraParams {
|
||||
extra[k] = v
|
||||
}
|
||||
|
||||
if v, ok := extra["input_type"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
req.InputType = s
|
||||
delete(extra, "input_type")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["truncate"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
req.Truncate = &s
|
||||
delete(extra, "truncate")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["embedding_types"]; ok {
|
||||
if ss, ok := v.([]string); ok {
|
||||
req.EmbeddingTypes = ss
|
||||
delete(extra, "embedding_types")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["images"]; ok {
|
||||
if ss, ok := v.([]string); ok {
|
||||
req.Images = ss
|
||||
delete(extra, "images")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["inputs"]; ok {
|
||||
if inputs, ok := v.([]BedrockCohereEmbeddingInput); ok {
|
||||
req.Inputs = inputs
|
||||
delete(extra, "inputs")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["max_tokens"]; ok {
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
req.MaxTokens = &n
|
||||
delete(extra, "max_tokens")
|
||||
case float64:
|
||||
i := int(n)
|
||||
req.MaxTokens = &i
|
||||
delete(extra, "max_tokens")
|
||||
}
|
||||
}
|
||||
if bifrostReq.Params.Dimensions != nil {
|
||||
req.OutputDimension = bifrostReq.Params.Dimensions
|
||||
}
|
||||
if len(extra) > 0 {
|
||||
req.ExtraParams = extra
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// DetermineEmbeddingModelType classifies an embedding model by substring
// match on its name. It returns "titan" for names containing
// "amazon.titan-embed-text" and "cohere" for names containing "cohere.embed",
// checking in that order; any other name yields an error.
func DetermineEmbeddingModelType(model string) (string, error) {
	families := [...]struct{ marker, family string }{
		{"amazon.titan-embed-text", "titan"},
		{"cohere.embed", "cohere"},
	}
	for _, f := range families {
		if strings.Contains(model, f.marker) {
			return f.family, nil
		}
	}
	return "", fmt.Errorf("unsupported embedding model: %s", model)
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to Bifrost format.
|
||||
// Bedrock returns embeddings as a raw [][]float32 when response_type is "embeddings_floats"
|
||||
// (the default, when no embedding_types are requested), and as a typed object when
|
||||
// response_type is "embeddings_by_type".
|
||||
func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("nil Bedrock Cohere embedding response")
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{Object: "list"}
|
||||
|
||||
switch r.ResponseType {
|
||||
case "embeddings_by_type":
|
||||
// Object form: {"float": [[...]], "int8": [[...]], "uint8": [[...]], "binary": [[...]], "ubinary": [[...]], "base64": [...]}
|
||||
var typed struct {
|
||||
Float [][]float32 `json:"float"`
|
||||
Base64 []string `json:"base64"`
|
||||
Int8 [][]int8 `json:"int8"`
|
||||
Uint8 [][]int32 `json:"uint8"` // int32 avoids []byte→base64 JSON issue
|
||||
Binary [][]int8 `json:"binary"`
|
||||
Ubinary [][]int32 `json:"ubinary"` // int32 avoids []byte→base64 JSON issue
|
||||
}
|
||||
if err := json.Unmarshal(r.Embeddings, &typed); err != nil {
|
||||
return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err)
|
||||
}
|
||||
if typed.Float != nil {
|
||||
for i, emb := range typed.Float {
|
||||
float64Emb := make([]float64, len(emb))
|
||||
for j, v := range emb {
|
||||
float64Emb[j] = float64(v)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb},
|
||||
})
|
||||
}
|
||||
}
|
||||
if typed.Base64 != nil {
|
||||
for i, emb := range typed.Base64 {
|
||||
e := emb
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingStr: &e},
|
||||
})
|
||||
}
|
||||
}
|
||||
for i, emb := range typed.Int8 {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb},
|
||||
})
|
||||
}
|
||||
for i, emb := range typed.Binary {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb},
|
||||
})
|
||||
}
|
||||
for i, emb := range typed.Uint8 {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb},
|
||||
})
|
||||
}
|
||||
for i, emb := range typed.Ubinary {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb},
|
||||
})
|
||||
}
|
||||
|
||||
default:
|
||||
// Default / "embeddings_floats": raw array form [[...], [...]]
|
||||
var floats [][]float32
|
||||
if err := json.Unmarshal(r.Embeddings, &floats); err != nil {
|
||||
return nil, fmt.Errorf("error parsing embeddings_floats: %w", err)
|
||||
}
|
||||
for i, emb := range floats {
|
||||
float64Emb := make([]float64, len(emb))
|
||||
for j, v := range emb {
|
||||
float64Emb[j] = float64(v)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
114
core/providers/bedrock/embedding_test.go
Normal file
114
core/providers/bedrock/embedding_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToBedrockCohereEmbeddingRequest(t *testing.T) {
|
||||
t.Run("returns error for nil request", func(t *testing.T) {
|
||||
req, err := ToBedrockCohereEmbeddingRequest(nil)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, req)
|
||||
assert.Contains(t, err.Error(), "nil")
|
||||
})
|
||||
|
||||
t.Run("returns error for missing input", func(t *testing.T) {
|
||||
req, err := ToBedrockCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{})
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, req)
|
||||
assert.Contains(t, err.Error(), "no input")
|
||||
})
|
||||
|
||||
t.Run("returns error for non-nil but empty input", func(t *testing.T) {
|
||||
req, err := ToBedrockCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
|
||||
Input: &schemas.EmbeddingInput{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, req)
|
||||
assert.Contains(t, err.Error(), "no input")
|
||||
})
|
||||
|
||||
t.Run("single text strips model and extracts typed params", func(t *testing.T) {
|
||||
text := "hello"
|
||||
truncate := "RIGHT"
|
||||
dimensions := 512
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "cohere.embed-english-v3",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
Dimensions: &dimensions,
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "search_query",
|
||||
"embedding_types": []string{"float"},
|
||||
"truncate": truncate,
|
||||
"max_tokens": float64(128),
|
||||
"trace_id": "req-123",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := ToBedrockCohereEmbeddingRequest(bifrostReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, "search_query", req.InputType)
|
||||
assert.Equal(t, []string{"hello"}, req.Texts)
|
||||
assert.Equal(t, []string{"float"}, req.EmbeddingTypes)
|
||||
assert.Equal(t, &dimensions, req.OutputDimension)
|
||||
assert.Equal(t, 128, *req.MaxTokens)
|
||||
require.NotNil(t, req.Truncate)
|
||||
assert.Equal(t, truncate, *req.Truncate)
|
||||
assert.Equal(t, map[string]interface{}{"trace_id": "req-123"}, req.ExtraParams)
|
||||
})
|
||||
|
||||
t.Run("multiple texts preserve bedrock body shape", func(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "cohere.embed-multilingual-v3",
|
||||
Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "search_document",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := ToBedrockCohereEmbeddingRequest(bifrostReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"hello", "world"}, req.Texts)
|
||||
assert.Equal(t, "search_document", req.InputType)
|
||||
})
|
||||
}
|
||||
|
||||
func TestToBedrockCohereEmbeddingRequestBodyOmitsModel(t *testing.T) {
|
||||
text := "hello"
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "cohere.embed-english-v3",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "search_document",
|
||||
"embedding_types": []string{"float"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
context.Background(),
|
||||
bifrostReq,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToBedrockCohereEmbeddingRequest(bifrostReq)
|
||||
},
|
||||
)
|
||||
require.Nil(t, bifrostErr)
|
||||
assert.NotContains(t, string(wireBody), `"model"`)
|
||||
assert.JSONEq(t, `{
|
||||
"input_type": "search_document",
|
||||
"texts": ["hello"],
|
||||
"embedding_types": ["float"]
|
||||
}`, string(wireBody))
|
||||
}
|
||||
34
core/providers/bedrock/errors.go
Normal file
34
core/providers/bedrock/errors.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func parseBedrockHTTPError(statusCode int, headers http.Header, body []byte) *schemas.BifrostError {
|
||||
fastResp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(fastResp)
|
||||
|
||||
fastResp.SetStatusCode(statusCode)
|
||||
for k, values := range headers {
|
||||
for _, value := range values {
|
||||
fastResp.Header.Add(k, value)
|
||||
}
|
||||
}
|
||||
fastResp.SetBody(body)
|
||||
|
||||
var errorResp BedrockError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(fastResp, &errorResp)
|
||||
if errorResp.Message != "" {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Message = errorResp.Message
|
||||
bifrostErr.Error.Code = errorResp.Code
|
||||
}
|
||||
|
||||
return bifrostErr
|
||||
}
|
||||
276
core/providers/bedrock/files.go
Normal file
276
core/providers/bedrock/files.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// escapeS3KeyForURL percent-escapes an S3 object key for use in a URL path.
//
// The key is split on "/" and every segment is passed through url.PathEscape on its own,
// then the segments are re-joined with "/". Escaping the whole key at once is not an option
// because url.PathEscape would turn the separators into "%2F". Empty segments (leading,
// trailing or doubled slashes) are preserved. An empty key yields an empty string.
func escapeS3KeyForURL(key string) string {
	if len(key) == 0 {
		return ""
	}
	segments := strings.Split(key, "/")
	escaped := make([]string, 0, len(segments))
	for _, segment := range segments {
		escaped = append(escaped, url.PathEscape(segment))
	}
	return strings.Join(escaped, "/")
}
|
||||
|
||||
// parseS3URI splits an S3 location into its bucket name and object key.
//
// For an "s3://bucket/key" URI, the bucket is everything between the scheme and the first
// "/" and the key is everything after that slash (empty when there is no slash). Any input
// that does not start with "s3://" is treated as a bare bucket name and returned unchanged
// with an empty key.
func parseS3URI(uri string) (bucket, key string) {
	const scheme = "s3://"
	if !strings.HasPrefix(uri, scheme) {
		// Assume it's just a bucket name
		return uri, ""
	}
	bucket, key, _ = strings.Cut(uri[len(scheme):], "/")
	return bucket, key
}
|
||||
|
||||
// S3ListObjectsResponse represents S3 ListObjectsV2 response.
|
||||
type S3ListObjectsResponse struct {
|
||||
Contents []S3Object `json:"contents"`
|
||||
IsTruncated bool `json:"isTruncated"`
|
||||
NextContinuationToken string `json:"nextContinuationToken,omitempty"`
|
||||
}
|
||||
|
||||
// S3Object describes a single object entry in an S3 list response.
type S3Object struct {
	// Key is the object key within the bucket.
	Key string `json:"key"`
	// Size is the object size as reported by the listing.
	Size int64 `json:"size"`
	// LastModified is the modification timestamp reported by the listing.
	LastModified time.Time `json:"lastModified"`
}
|
||||
|
||||
// parseS3ListResponse parses S3 ListObjectsV2 XML response.
|
||||
func parseS3ListResponse(body []byte, resp *S3ListObjectsResponse) error {
|
||||
// S3 returns XML, so we need to parse it
|
||||
// Try JSON first (some S3-compatible services return JSON)
|
||||
if err := sonic.Unmarshal(body, resp); err == nil && len(resp.Contents) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse XML using simple string matching for key fields
|
||||
// This is a lightweight approach that doesn't require encoding/xml
|
||||
bodyStr := string(body)
|
||||
|
||||
// Parse IsTruncated
|
||||
if strings.Contains(bodyStr, "<IsTruncated>true</IsTruncated>") {
|
||||
resp.IsTruncated = true
|
||||
}
|
||||
|
||||
// Parse NextContinuationToken
|
||||
if start := strings.Index(bodyStr, "<NextContinuationToken>"); start >= 0 {
|
||||
start += len("<NextContinuationToken>")
|
||||
if end := strings.Index(bodyStr[start:], "</NextContinuationToken>"); end >= 0 {
|
||||
resp.NextContinuationToken = bodyStr[start : start+end]
|
||||
}
|
||||
}
|
||||
|
||||
// Parse Contents
|
||||
contents := bodyStr
|
||||
for {
|
||||
start := strings.Index(contents, "<Contents>")
|
||||
if start < 0 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(contents[start:], "</Contents>")
|
||||
if end < 0 {
|
||||
break
|
||||
}
|
||||
|
||||
contentBlock := contents[start : start+end+len("</Contents>")]
|
||||
contents = contents[start+end+len("</Contents>"):]
|
||||
|
||||
obj := S3Object{}
|
||||
|
||||
// Parse Key
|
||||
if keyStart := strings.Index(contentBlock, "<Key>"); keyStart >= 0 {
|
||||
keyStart += len("<Key>")
|
||||
if keyEnd := strings.Index(contentBlock[keyStart:], "</Key>"); keyEnd >= 0 {
|
||||
obj.Key = html.UnescapeString(contentBlock[keyStart : keyStart+keyEnd])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse Size
|
||||
if sizeStart := strings.Index(contentBlock, "<Size>"); sizeStart >= 0 {
|
||||
sizeStart += len("<Size>")
|
||||
if sizeEnd := strings.Index(contentBlock[sizeStart:], "</Size>"); sizeEnd >= 0 {
|
||||
sizeStr := contentBlock[sizeStart : sizeStart+sizeEnd]
|
||||
fmt.Sscanf(sizeStr, "%d", &obj.Size)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse LastModified
|
||||
if lmStart := strings.Index(contentBlock, "<LastModified>"); lmStart >= 0 {
|
||||
lmStart += len("<LastModified>")
|
||||
if lmEnd := strings.Index(contentBlock[lmStart:], "</LastModified>"); lmEnd >= 0 {
|
||||
lmStr := contentBlock[lmStart : lmStart+lmEnd]
|
||||
if t, err := time.Parse(time.RFC3339Nano, lmStr); err == nil {
|
||||
obj.LastModified = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if obj.Key != "" {
|
||||
resp.Contents = append(resp.Contents, obj)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== BEDROCK FILE TYPE CONVERTERS ====================
|
||||
|
||||
// ToBedrockFileUploadResponse converts a Bifrost file upload response to Bedrock format.
|
||||
func ToBedrockFileUploadResponse(resp *schemas.BifrostFileUploadResponse) *BedrockFileUploadResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse S3 URI to get bucket and key
|
||||
bucket, key := parseS3URI(resp.ID)
|
||||
|
||||
return &BedrockFileUploadResponse{
|
||||
S3Uri: resp.ID,
|
||||
Bucket: bucket,
|
||||
Key: key,
|
||||
SizeBytes: resp.Bytes,
|
||||
ContentType: "application/jsonl",
|
||||
CreatedAt: resp.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileListResponse converts a Bifrost file list response to Bedrock format.
|
||||
func ToBedrockFileListResponse(resp *schemas.BifrostFileListResponse) *BedrockFileListResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
files := make([]BedrockFileInfo, len(resp.Data))
|
||||
for i, f := range resp.Data {
|
||||
_, key := parseS3URI(f.ID)
|
||||
files[i] = BedrockFileInfo{
|
||||
S3Uri: f.ID,
|
||||
Key: key,
|
||||
SizeBytes: f.Bytes,
|
||||
LastModified: f.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
return &BedrockFileListResponse{
|
||||
Files: files,
|
||||
IsTruncated: resp.HasMore,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileRetrieveResponse converts a Bifrost file retrieve response to Bedrock format.
|
||||
func ToBedrockFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *BedrockFileRetrieveResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, key := parseS3URI(resp.ID)
|
||||
|
||||
return &BedrockFileRetrieveResponse{
|
||||
S3Uri: resp.ID,
|
||||
Key: key,
|
||||
SizeBytes: resp.Bytes,
|
||||
LastModified: resp.CreatedAt,
|
||||
ContentType: "application/jsonl",
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileDeleteResponse converts a Bifrost file delete response to Bedrock format.
|
||||
func ToBedrockFileDeleteResponse(resp *schemas.BifrostFileDeleteResponse) *BedrockFileDeleteResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BedrockFileDeleteResponse{
|
||||
S3Uri: resp.ID,
|
||||
Deleted: resp.Deleted,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileContentResponse converts a Bifrost file content response to Bedrock format.
|
||||
func ToBedrockFileContentResponse(resp *schemas.BifrostFileContentResponse) *BedrockFileContentResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BedrockFileContentResponse{
|
||||
S3Uri: resp.FileID,
|
||||
Content: resp.Content,
|
||||
ContentType: resp.ContentType,
|
||||
SizeBytes: int64(len(resp.Content)),
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== S3 API XML FORMATTERS ====================
|
||||
|
||||
// ToS3ListObjectsV2XML converts a Bifrost file list response to S3 ListObjectsV2 XML format.
|
||||
func ToS3ListObjectsV2XML(resp *schemas.BifrostFileListResponse, bucket, prefix string, maxKeys int) []byte {
|
||||
if resp == nil {
|
||||
return []byte(`<?xml version="1.0" encoding="UTF-8"?><ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></ListBucketResult>`)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
|
||||
sb.WriteString(`<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">`)
|
||||
sb.WriteString(fmt.Sprintf("<Name>%s</Name>", bucket))
|
||||
sb.WriteString(fmt.Sprintf("<Prefix>%s</Prefix>", prefix))
|
||||
sb.WriteString(fmt.Sprintf("<KeyCount>%d</KeyCount>", len(resp.Data)))
|
||||
sb.WriteString(fmt.Sprintf("<MaxKeys>%d</MaxKeys>", maxKeys))
|
||||
if resp.HasMore {
|
||||
sb.WriteString("<IsTruncated>true</IsTruncated>")
|
||||
if resp.After != nil && *resp.After != "" {
|
||||
sb.WriteString(fmt.Sprintf("<NextContinuationToken>%s</NextContinuationToken>", *resp.After))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("<IsTruncated>false</IsTruncated>")
|
||||
}
|
||||
|
||||
for _, f := range resp.Data {
|
||||
// Extract key from S3 URI
|
||||
_, key := parseS3URI(f.ID)
|
||||
sb.WriteString("<Contents>")
|
||||
sb.WriteString(fmt.Sprintf("<Key>%s</Key>", key))
|
||||
sb.WriteString(fmt.Sprintf("<Size>%d</Size>", f.Bytes))
|
||||
if f.CreatedAt > 0 {
|
||||
sb.WriteString(fmt.Sprintf("<LastModified>%s</LastModified>", time.Unix(f.CreatedAt, 0).UTC().Format(time.RFC3339)))
|
||||
}
|
||||
sb.WriteString("<StorageClass>STANDARD</StorageClass>")
|
||||
sb.WriteString("</Contents>")
|
||||
}
|
||||
|
||||
sb.WriteString("</ListBucketResult>")
|
||||
return []byte(sb.String())
|
||||
}
|
||||
|
||||
// ToS3ErrorXML converts an error to S3 error XML format.
//
// The result is an XML declaration followed by an <Error> element containing <Code>,
// <Message>, <Resource> and <RequestId>, in that order. Every value is entity-escaped so
// that messages or resource paths containing "&", "<", ">" or quotes still yield
// well-formed XML.
func ToS3ErrorXML(code, message, resource, requestID string) []byte {
	var sb strings.Builder
	sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
	sb.WriteString("<Error>")
	fmt.Fprintf(&sb, "<Code>%s</Code>", html.EscapeString(code))
	fmt.Fprintf(&sb, "<Message>%s</Message>", html.EscapeString(message))
	fmt.Fprintf(&sb, "<Resource>%s</Resource>", html.EscapeString(resource))
	fmt.Fprintf(&sb, "<RequestId>%s</RequestId>", html.EscapeString(requestID))
	sb.WriteString("</Error>")
	return []byte(sb.String())
}
|
||||
742
core/providers/bedrock/images.go
Normal file
742
core/providers/bedrock/images.go
Normal file
@@ -0,0 +1,742 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// mapQualityToBedrock translates a quality value into the vocabulary Bedrock uses.
//
// The input is trimmed and lower-cased before matching:
//   - "low", "medium", "standard" -> "standard"
//   - "high", "premium"           -> "premium"
//
// A nil input returns nil. Any unrecognized value is returned as the original pointer,
// untouched (not trimmed or lower-cased). Recognized values return a pointer to a fresh
// string.
func mapQualityToBedrock(quality *string) *string {
	if quality == nil {
		return nil
	}

	var mapped string
	switch strings.ToLower(strings.TrimSpace(*quality)) {
	case "low", "medium", "standard":
		mapped = "standard"
	case "high", "premium":
		mapped = "premium"
	default:
		return quality
	}
	return &mapped
}
|
||||
|
||||
// isStabilityAIModel reports whether the model ID, compared case-insensitively,
// contains the substring "stability.".
func isStabilityAIModel(model string) bool {
	normalized := strings.ToLower(model)
	return strings.Contains(normalized, "stability.")
}
|
||||
|
||||
// isPromptOnlyImageGenerationModel reports whether the model ID, compared
// case-insensitively, contains the substring "image". It is meant to identify image
// generation models that take a flat {"prompt": "..."} payload (no taskType field).
//
// The check itself does not filter out Stability AI IDs that contain "image"; those are
// handled separately because they also support image edit.
// NOTE(review): callers presumably test isStabilityAIModel first — confirm.
func isPromptOnlyImageGenerationModel(model string) bool {
	return strings.Contains(strings.ToLower(model), "image")
}
|
||||
|
||||
// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI
|
||||
// flat request format used by Bedrock (stability.stable-image-* models).
|
||||
func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
if request.Input == nil {
|
||||
return nil, fmt.Errorf("request input is required")
|
||||
}
|
||||
|
||||
req := &StabilityAIImageGenerationRequest{
|
||||
Prompt: request.Input.Prompt,
|
||||
}
|
||||
|
||||
if request.Params != nil {
|
||||
if request.Params.AspectRatio != nil {
|
||||
req.AspectRatio = request.Params.AspectRatio
|
||||
}
|
||||
if request.Params.OutputFormat != nil {
|
||||
req.OutputFormat = request.Params.OutputFormat
|
||||
}
|
||||
if request.Params.Seed != nil {
|
||||
req.Seed = request.Params.Seed
|
||||
}
|
||||
if request.Params.NegativePrompt != nil {
|
||||
req.NegativePrompt = request.Params.NegativePrompt
|
||||
}
|
||||
if request.Params.ExtraParams != nil {
|
||||
// aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set
|
||||
if req.AspectRatio == nil {
|
||||
if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok {
|
||||
delete(request.Params.ExtraParams, "aspect_ratio")
|
||||
req.AspectRatio = ar
|
||||
}
|
||||
}
|
||||
req.ExtraParams = request.Params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image generation request
|
||||
func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
|
||||
if request.Input == nil {
|
||||
return nil, fmt.Errorf("request input is required")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockImageGenerationRequest{
|
||||
TaskType: schemas.Ptr(TaskTypeTextImage),
|
||||
TextToImageParams: &BedrockTextToImageParams{
|
||||
Text: request.Input.Prompt,
|
||||
},
|
||||
ImageGenerationConfig: &ImageGenerationConfig{},
|
||||
}
|
||||
|
||||
if request.Params != nil {
|
||||
if request.Params.N != nil {
|
||||
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
|
||||
}
|
||||
if request.Params.NegativePrompt != nil {
|
||||
bedrockReq.TextToImageParams.NegativeText = request.Params.NegativePrompt
|
||||
}
|
||||
if request.Params.Seed != nil {
|
||||
bedrockReq.ImageGenerationConfig.Seed = request.Params.Seed
|
||||
}
|
||||
if request.Params.Quality != nil {
|
||||
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(request.Params.Quality)
|
||||
}
|
||||
if request.Params.Style != nil {
|
||||
bedrockReq.TextToImageParams.Style = request.Params.Style
|
||||
}
|
||||
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
|
||||
|
||||
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
|
||||
if len(size) != 2 {
|
||||
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
|
||||
}
|
||||
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
|
||||
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
|
||||
}
|
||||
if request.Params.ExtraParams != nil {
|
||||
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
|
||||
delete(request.Params.ExtraParams, "cfgScale")
|
||||
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
|
||||
}
|
||||
bedrockReq.ExtraParams = request.Params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to
|
||||
// the native Bedrock invoke API response format used by Stability AI models.
|
||||
// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas.
|
||||
func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) {
|
||||
if response == nil {
|
||||
return nil, fmt.Errorf("response is nil")
|
||||
}
|
||||
result := &BedrockImageGenerationResponse{}
|
||||
for _, d := range response.Data {
|
||||
result.Images = append(result.Images, d.B64JSON)
|
||||
}
|
||||
if response.ImageGenerationResponseParameters != nil {
|
||||
result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons
|
||||
result.Seeds = response.ImageGenerationResponseParameters.Seeds
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request
|
||||
func ToBedrockImageVariationRequest(request *schemas.BifrostImageVariationRequest) (*BedrockImageVariationRequest, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
|
||||
if request.Input == nil || request.Input.Image.Image == nil || len(request.Input.Image.Image) == 0 {
|
||||
return nil, fmt.Errorf("request.Input.Image is required")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockImageVariationRequest{
|
||||
TaskType: schemas.Ptr(TaskTypeImageVariation),
|
||||
ImageVariationParams: &BedrockImageVariationParams{
|
||||
Images: []string{},
|
||||
},
|
||||
ImageGenerationConfig: &ImageGenerationConfig{},
|
||||
}
|
||||
|
||||
// Convert all images to base64 strings
|
||||
// Primary image from Input.Image
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Image.Image)
|
||||
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, imageBase64)
|
||||
|
||||
// Additional images from ExtraParams (stored as [][]byte)
|
||||
if request.Params != nil && request.Params.ExtraParams != nil {
|
||||
if additionalImages, ok := request.Params.ExtraParams["images"]; ok {
|
||||
delete(request.Params.ExtraParams, "images")
|
||||
// Handle array of byte arrays (stored by HTTP handler)
|
||||
if imagesArray, ok := additionalImages.([][]byte); ok {
|
||||
for _, imgBytes := range imagesArray {
|
||||
if len(imgBytes) > 0 {
|
||||
additionalBase64 := base64.StdEncoding.EncodeToString(imgBytes)
|
||||
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, additionalBase64)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract optional fields from ExtraParams
|
||||
if prompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["prompt"]); ok {
|
||||
delete(request.Params.ExtraParams, "prompt")
|
||||
bedrockReq.ImageVariationParams.Text = prompt
|
||||
}
|
||||
if negativeText, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["negativeText"]); ok {
|
||||
delete(request.Params.ExtraParams, "negativeText")
|
||||
bedrockReq.ImageVariationParams.NegativeText = negativeText
|
||||
}
|
||||
|
||||
if similarityStrength, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["similarityStrength"]); ok {
|
||||
delete(request.Params.ExtraParams, "similarityStrength")
|
||||
// Validate similarityStrength range (0.2 to 1.0)
|
||||
if *similarityStrength < 0.2 || *similarityStrength > 1.0 {
|
||||
return nil, fmt.Errorf("similarityStrength must be between 0.2 and 1.0, got %f", *similarityStrength)
|
||||
}
|
||||
bedrockReq.ImageVariationParams.SimilarityStrength = similarityStrength
|
||||
}
|
||||
bedrockReq.ExtraParams = request.Params.ExtraParams
|
||||
}
|
||||
|
||||
// Map standard params to ImageGenerationConfig
|
||||
if request.Params != nil {
|
||||
if request.Params.N != nil {
|
||||
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
|
||||
}
|
||||
|
||||
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
|
||||
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
|
||||
if len(size) != 2 {
|
||||
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
|
||||
}
|
||||
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
|
||||
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
|
||||
}
|
||||
|
||||
// Extract quality and cfgScale from ExtraParams
|
||||
if request.Params.ExtraParams != nil {
|
||||
if quality, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["quality"]); ok {
|
||||
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(quality)
|
||||
}
|
||||
|
||||
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
|
||||
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToBedrockImageEditRequest converts a Bifrost image edit request to a Bedrock image edit request
|
||||
func ToBedrockImageEditRequest(request *schemas.BifrostImageEditRequest) (*BedrockImageEditRequest, error) {
|
||||
// Validate request
|
||||
if request == nil || request.Input == nil {
|
||||
return nil, fmt.Errorf("request or input is nil")
|
||||
}
|
||||
|
||||
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
|
||||
return nil, fmt.Errorf("at least one image is required")
|
||||
}
|
||||
|
||||
// Validate and extract type (required)
|
||||
if request.Params == nil || request.Params.Type == nil {
|
||||
return nil, fmt.Errorf("type field is required (must be inpainting, outpainting, or background_removal)")
|
||||
}
|
||||
|
||||
editType := strings.ToLower(*request.Params.Type)
|
||||
|
||||
// Convert first image to base64
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
||||
|
||||
bedrockReq := &BedrockImageEditRequest{}
|
||||
|
||||
switch editType {
|
||||
case "inpainting":
|
||||
bedrockReq.TaskType = schemas.Ptr(TaskTypeInpainting)
|
||||
bedrockReq.InPaintingParams = buildInPaintingParams(imageBase64, request)
|
||||
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
|
||||
|
||||
case "outpainting":
|
||||
bedrockReq.TaskType = schemas.Ptr(TaskTypeOutpainting)
|
||||
bedrockReq.OutPaintingParams = buildOutPaintingParams(imageBase64, request)
|
||||
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
|
||||
|
||||
case "background_removal":
|
||||
bedrockReq.TaskType = schemas.Ptr(TaskTypeBackgroundRemoval)
|
||||
bedrockReq.BackgroundRemovalParams = &BedrockBackgroundRemovalParams{
|
||||
Image: imageBase64,
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type for Bedrock: %s (must be inpainting, outpainting, or background_removal)", editType)
|
||||
}
|
||||
|
||||
bedrockReq.ExtraParams = request.Params.ExtraParams
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func buildInPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockInPaintingParams {
|
||||
params := &BedrockInPaintingParams{
|
||||
Image: imageBase64,
|
||||
Text: request.Input.Prompt,
|
||||
}
|
||||
|
||||
if request.Params.NegativePrompt != nil {
|
||||
params.NegativeText = request.Params.NegativePrompt
|
||||
}
|
||||
|
||||
if request.Params.ExtraParams != nil {
|
||||
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
|
||||
delete(request.Params.ExtraParams, "mask_prompt")
|
||||
params.MaskPrompt = maskPrompt
|
||||
}
|
||||
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
|
||||
delete(request.Params.ExtraParams, "return_mask")
|
||||
params.ReturnMask = returnMask
|
||||
}
|
||||
}
|
||||
|
||||
// Convert mask to base64 if present
|
||||
if len(request.Params.Mask) > 0 {
|
||||
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
||||
params.MaskImage = &maskBase64
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func buildOutPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockOutPaintingParams {
|
||||
params := &BedrockOutPaintingParams{
|
||||
Text: request.Input.Prompt,
|
||||
Image: imageBase64,
|
||||
}
|
||||
|
||||
if request.Params.NegativePrompt != nil {
|
||||
params.NegativeText = request.Params.NegativePrompt
|
||||
}
|
||||
|
||||
if request.Params.ExtraParams != nil {
|
||||
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
|
||||
delete(request.Params.ExtraParams, "mask_prompt")
|
||||
params.MaskPrompt = maskPrompt
|
||||
}
|
||||
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
|
||||
delete(request.Params.ExtraParams, "return_mask")
|
||||
params.ReturnMask = returnMask
|
||||
}
|
||||
if outPaintingMode, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["outpainting_mode"]); ok {
|
||||
// Validate mode
|
||||
mode := strings.ToUpper(*outPaintingMode)
|
||||
if mode == "DEFAULT" || mode == "PRECISE" {
|
||||
delete(request.Params.ExtraParams, "outpainting_mode")
|
||||
params.OutPaintingMode = &mode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert mask to base64 if present
|
||||
if len(request.Params.Mask) > 0 {
|
||||
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
||||
params.MaskImage = &maskBase64
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGenerationConfig {
|
||||
config := &ImageGenerationConfig{}
|
||||
|
||||
if params.N != nil {
|
||||
config.NumberOfImages = params.N
|
||||
}
|
||||
|
||||
// Parse size (reuse logic from image generation)
|
||||
if params.Size != nil && strings.TrimSpace(strings.ToLower(*params.Size)) != "auto" {
|
||||
size := strings.Split(strings.TrimSpace(strings.ToLower(*params.Size)), "x")
|
||||
if len(size) == 2 {
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err == nil {
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err == nil {
|
||||
config.Width = schemas.Ptr(width)
|
||||
config.Height = schemas.Ptr(height)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if params.Quality != nil {
|
||||
config.Quality = mapQualityToBedrock(params.Quality)
|
||||
}
|
||||
|
||||
if params.Seed != nil {
|
||||
config.Seed = params.Seed
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["cfgScale"]); ok {
|
||||
delete(params.ExtraParams, "cfgScale")
|
||||
config.CfgScale = cfgScale
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// getStabilityAITaskTypeFromParams translates the generic image edit Type value into a
// Stability AI task type string. Matching is case-insensitive. It returns "" when the
// value is not a recognized Stability AI task type.
func getStabilityAITaskTypeFromParams(t string) string {
	key := strings.ToLower(t)
	switch key {
	// Names with aliases, or whose task type is not a plain underscore-to-hyphen rewrite.
	case "inpainting", "inpaint":
		return "inpaint"
	case "outpainting", "outpaint":
		return "outpaint"
	case "background_removal", "remove_background":
		return "remove-bg"
	// For the remaining names the task type is the name with "_" replaced by "-".
	case "erase_object",
		"upscale_fast", "upscale_creative", "upscale_conservative",
		"recolor", "search_replace",
		"control_sketch", "control_structure",
		"style_guide", "style_transfer":
		return strings.ReplaceAll(key, "_", "-")
	}
	return ""
}
|
||||
|
||||
// getStabilityAIEditTaskType infers the Stability AI edit task from the model name by
// looking for a known fragment in the lower-cased name. It returns an error when no
// known fragment is present.
func getStabilityAIEditTaskType(model string) (string, error) {
	// Checked in order; the first fragment found in the model name wins.
	known := [...]struct{ fragment, task string }{
		{"stable-creative-upscale", "upscale-creative"},
		{"stable-conservative-upscale", "upscale-conservative"},
		{"stable-fast-upscale", "upscale-fast"},
		{"stable-image-inpaint", "inpaint"},
		{"stable-outpaint", "outpaint"},
		{"stable-image-search-recolor", "recolor"},
		{"stable-image-search-replace", "search-replace"},
		{"stable-image-erase-object", "erase-object"},
		{"stable-image-remove-background", "remove-bg"},
		{"stable-image-control-sketch", "control-sketch"},
		{"stable-image-control-structure", "control-structure"},
		{"stable-image-style-guide", "style-guide"},
		{"stable-style-transfer", "style-transfer"},
	}

	lowered := strings.ToLower(model)
	for _, entry := range known {
		if strings.Contains(lowered, entry.fragment) {
			return entry.task, nil
		}
	}
	// The error quotes the model exactly as passed in, not the lower-cased form.
	return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model)
}
|
||||
|
||||
// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request
|
||||
// format used by Bedrock edit models. Only fields valid for the detected task type are populated.
|
||||
// deployment is the resolved model identifier (after applying any deployment alias mapping); it is
|
||||
// used for task-type inference so that alias-mapped models route correctly.
|
||||
func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) {
|
||||
if request == nil || request.Input == nil {
|
||||
return nil, fmt.Errorf("request or input is nil")
|
||||
}
|
||||
|
||||
var taskType string
|
||||
if request.Params != nil && request.Params.Type != nil {
|
||||
taskType = getStabilityAITaskTypeFromParams(*request.Params.Type)
|
||||
}
|
||||
if taskType == "" {
|
||||
var err error
|
||||
taskType, err = getStabilityAIEditTaskType(deployment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &StabilityAIImageEditRequest{}
|
||||
|
||||
// Image sourcing
|
||||
if taskType == "style-transfer" {
|
||||
if len(request.Input.Images) != 2 {
|
||||
return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image")
|
||||
}
|
||||
if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 {
|
||||
return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image")
|
||||
}
|
||||
initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
||||
styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image)
|
||||
req.InitImage = &initB64
|
||||
req.StyleImage = &styleB64
|
||||
} else {
|
||||
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
|
||||
return nil, fmt.Errorf("at least one image is required")
|
||||
}
|
||||
imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
||||
req.Image = &imageB64
|
||||
}
|
||||
|
||||
// Common fields populated based on task allowlist
|
||||
prompt := request.Input.Prompt
|
||||
switch taskType {
|
||||
case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure",
|
||||
"style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer":
|
||||
req.Prompt = &prompt
|
||||
}
|
||||
|
||||
// Negative prompt
|
||||
if request.Params != nil && request.Params.NegativePrompt != nil {
|
||||
switch taskType {
|
||||
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
|
||||
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
|
||||
req.NegativePrompt = request.Params.NegativePrompt
|
||||
}
|
||||
}
|
||||
|
||||
// Seed
|
||||
if request.Params != nil && request.Params.Seed != nil {
|
||||
switch taskType {
|
||||
case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch",
|
||||
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
|
||||
req.Seed = request.Params.Seed
|
||||
}
|
||||
}
|
||||
|
||||
// Mask (from Params.Mask bytes)
|
||||
if request.Params != nil && len(request.Params.Mask) > 0 {
|
||||
switch taskType {
|
||||
case "inpaint", "erase-object":
|
||||
maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
||||
req.Mask = &maskB64
|
||||
}
|
||||
}
|
||||
|
||||
// ExtraParams
|
||||
if request.Params != nil {
|
||||
// Typed OutputFormat takes priority over ExtraParams
|
||||
if request.Params.OutputFormat != nil {
|
||||
req.OutputFormat = request.Params.OutputFormat
|
||||
}
|
||||
|
||||
if request.Params.ExtraParams != nil {
|
||||
ep := make(map[string]interface{}, len(request.Params.ExtraParams))
|
||||
for k, v := range request.Params.ExtraParams {
|
||||
ep[k] = v
|
||||
}
|
||||
|
||||
// output_format — all tasks (fallback if not already set by typed field)
|
||||
if req.OutputFormat == nil {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok {
|
||||
delete(ep, "output_format")
|
||||
req.OutputFormat = v
|
||||
}
|
||||
}
|
||||
|
||||
// style_preset
|
||||
switch taskType {
|
||||
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
|
||||
"control-structure", "style-guide", "upscale-creative":
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok {
|
||||
delete(ep, "style_preset")
|
||||
req.StylePreset = v
|
||||
}
|
||||
}
|
||||
|
||||
// grow_mask
|
||||
switch taskType {
|
||||
case "inpaint", "recolor", "search-replace", "erase-object":
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok {
|
||||
delete(ep, "grow_mask")
|
||||
req.GrowMask = v
|
||||
}
|
||||
}
|
||||
|
||||
// outpaint directional fields
|
||||
if taskType == "outpaint" {
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok {
|
||||
delete(ep, "left")
|
||||
req.Left = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok {
|
||||
delete(ep, "right")
|
||||
req.Right = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok {
|
||||
delete(ep, "up")
|
||||
req.Up = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok {
|
||||
delete(ep, "down")
|
||||
req.Down = v
|
||||
}
|
||||
}
|
||||
|
||||
// creativity
|
||||
switch taskType {
|
||||
case "upscale-creative", "upscale-conservative", "outpaint":
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok {
|
||||
delete(ep, "creativity")
|
||||
req.Creativity = v
|
||||
}
|
||||
}
|
||||
|
||||
// select_prompt (recolor)
|
||||
if taskType == "recolor" {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok {
|
||||
delete(ep, "select_prompt")
|
||||
req.SelectPrompt = v
|
||||
}
|
||||
}
|
||||
|
||||
// search_prompt (search-replace)
|
||||
if taskType == "search-replace" {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok {
|
||||
delete(ep, "search_prompt")
|
||||
req.SearchPrompt = v
|
||||
}
|
||||
}
|
||||
|
||||
// control_strength
|
||||
switch taskType {
|
||||
case "control-sketch", "control-structure":
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok {
|
||||
delete(ep, "control_strength")
|
||||
req.ControlStrength = v
|
||||
}
|
||||
}
|
||||
|
||||
// style-guide fields
|
||||
if taskType == "style-guide" {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok {
|
||||
delete(ep, "aspect_ratio")
|
||||
req.AspectRatio = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok {
|
||||
delete(ep, "fidelity")
|
||||
req.Fidelity = v
|
||||
}
|
||||
}
|
||||
|
||||
// style-transfer fields
|
||||
if taskType == "style-transfer" {
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok {
|
||||
delete(ep, "style_strength")
|
||||
req.StyleStrength = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok {
|
||||
delete(ep, "composition_fidelity")
|
||||
req.CompositionFidelity = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok {
|
||||
delete(ep, "change_strength")
|
||||
req.ChangeStrength = v
|
||||
}
|
||||
}
|
||||
|
||||
req.ExtraParams = ep
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required per-task fields
|
||||
if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") {
|
||||
return nil, fmt.Errorf("select_prompt is required for stability ai recolor task")
|
||||
}
|
||||
if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") {
|
||||
return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response
|
||||
func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostImageGenerationResponse{}
|
||||
|
||||
if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 {
|
||||
bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{
|
||||
FinishReasons: append([]*string(nil), response.FinishReasons...),
|
||||
Seeds: append([]int(nil), response.Seeds...),
|
||||
}
|
||||
}
|
||||
|
||||
for index, image := range response.Images {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{
|
||||
B64JSON: image,
|
||||
Index: index,
|
||||
})
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
1442
core/providers/bedrock/invoke.go
Normal file
1442
core/providers/bedrock/invoke.go
Normal file
File diff suppressed because it is too large
Load Diff
130
core/providers/bedrock/models.go
Normal file
130
core/providers/bedrock/models.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// BedrockRerankRequest is the Bedrock Agent Runtime rerank request body.
|
||||
type BedrockRerankRequest struct {
|
||||
Queries []BedrockRerankQuery `json:"queries"`
|
||||
Sources []BedrockRerankSource `json:"sources"`
|
||||
RerankingConfiguration BedrockRerankingConfiguration `json:"rerankingConfiguration"`
|
||||
}
|
||||
|
||||
// GetExtraParams implements RequestBodyWithExtraParams.
|
||||
func (*BedrockRerankRequest) GetExtraParams() map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fixed "type" discriminator values used in Bedrock rerank request bodies.
const (
	// bedrockRerankQueryTypeText is the query type for plain-text queries.
	bedrockRerankQueryTypeText = "TEXT"
	// bedrockRerankSourceTypeInline is the source type for documents sent inline.
	bedrockRerankSourceTypeInline = "INLINE"
	// bedrockRerankInlineDocumentTypeText is the inline document type for plain text.
	bedrockRerankInlineDocumentTypeText = "TEXT"
	// bedrockRerankConfigurationTypeBedrock is the reranking configuration type.
	bedrockRerankConfigurationTypeBedrock = "BEDROCK_RERANKING_MODEL"
)
|
||||
|
||||
// Building blocks of the Bedrock rerank request body.
type (
	// BedrockRerankQuery is a single query entry of a rerank request.
	BedrockRerankQuery struct {
		Type      string               `json:"type"`
		TextQuery BedrockRerankTextRef `json:"textQuery"`
	}

	// BedrockRerankSource is a single document source of a rerank request.
	BedrockRerankSource struct {
		Type                 string                    `json:"type"`
		InlineDocumentSource BedrockRerankInlineSource `json:"inlineDocumentSource"`
	}

	// BedrockRerankInlineSource carries a document supplied inline in the request.
	BedrockRerankInlineSource struct {
		Type         string                 `json:"type"`
		TextDocument BedrockRerankTextValue `json:"textDocument"`
	}

	// BedrockRerankTextRef wraps the text of a query.
	BedrockRerankTextRef struct {
		Text string `json:"text"`
	}

	// BedrockRerankTextValue wraps the text of a document.
	BedrockRerankTextValue struct {
		Text string `json:"text"`
	}

	// BedrockRerankingConfiguration is the top-level reranking configuration.
	BedrockRerankingConfiguration struct {
		Type                          string                             `json:"type"`
		BedrockRerankingConfiguration BedrockRerankingModelConfiguration `json:"bedrockRerankingConfiguration"`
	}

	// BedrockRerankingModelConfiguration pairs the model configuration with an optional
	// cap on the number of results; the cap is omitted from JSON when nil.
	BedrockRerankingModelConfiguration struct {
		ModelConfiguration BedrockRerankModelConfiguration `json:"modelConfiguration"`
		NumberOfResults    *int                            `json:"numberOfResults,omitempty"`
	}

	// BedrockRerankModelConfiguration identifies the reranking model by ARN and carries
	// optional additional model request fields, omitted from JSON when empty.
	BedrockRerankModelConfiguration struct {
		ModelARN                     string                 `json:"modelArn"`
		AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
	}
)
|
||||
|
||||
// BedrockRerankResponse is the Bedrock Agent Runtime rerank response body.
|
||||
type BedrockRerankResponse struct {
|
||||
Results []BedrockRerankResult `json:"results"`
|
||||
NextToken *string `json:"nextToken,omitempty"`
|
||||
}
|
||||
|
||||
type BedrockRerankResult struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevanceScore"`
|
||||
Document *BedrockRerankResponseDocument `json:"document,omitempty"`
|
||||
}
|
||||
|
||||
type BedrockRerankResponseDocument struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"`
|
||||
}
|
||||
|
||||
func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.ModelSummaries)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.ModelSummaries {
|
||||
for _, result := range pipeline.FilterModel(model.ModelID) {
|
||||
modelEntry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.ModelName),
|
||||
OwnedBy: schemas.Ptr(model.ProviderName),
|
||||
Architecture: &schemas.Architecture{
|
||||
InputModalities: model.InputModalities,
|
||||
OutputModalities: model.OutputModalities,
|
||||
},
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
modelEntry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
55
core/providers/bedrock/payload_ordering_test.go
Normal file
55
core/providers/bedrock/payload_ordering_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPayloadOrdering_BedrockConverseRequest(t *testing.T) {
|
||||
req := &BedrockConverseRequest{
|
||||
Messages: []BedrockMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []BedrockContentBlock{
|
||||
{Text: schemas.Ptr("hello")},
|
||||
},
|
||||
},
|
||||
},
|
||||
InferenceConfig: &BedrockInferenceConfig{
|
||||
Temperature: schemas.Ptr(0.7),
|
||||
MaxTokens: schemas.Ptr(1024),
|
||||
},
|
||||
ToolConfig: &BedrockToolConfig{
|
||||
Tools: []BedrockTool{
|
||||
{
|
||||
ToolSpec: &BedrockToolSpec{
|
||||
Name: "get_weather",
|
||||
Description: schemas.Ptr("Get weather"),
|
||||
InputSchema: BedrockToolInputSchema{
|
||||
JSON: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
golden := `{"messages":[{"role":"user","content":[{"text":"hello"}]}],"inferenceConfig":{"maxTokens":1024,"temperature":0.7},"toolConfig":{"tools":[{"toolSpec":{"name":"get_weather","description":"Get weather","inputSchema":{"json":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}}]}}`
|
||||
|
||||
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
|
||||
|
||||
// Determinism: 100 iterations must produce identical bytes
|
||||
for i := 0; i < 100; i++ {
|
||||
iter, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
|
||||
}
|
||||
}
|
||||
168
core/providers/bedrock/rerank.go
Normal file
168
core/providers/bedrock/rerank.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockRerankRequest converts a Bifrost rerank request into Bedrock Agent Runtime format.
|
||||
func ToBedrockRerankRequest(bifrostReq *schemas.BifrostRerankRequest, modelARN string) (*BedrockRerankRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost rerank request is nil")
|
||||
}
|
||||
if strings.TrimSpace(modelARN) == "" {
|
||||
return nil, fmt.Errorf("bedrock rerank model ARN is empty")
|
||||
}
|
||||
if len(bifrostReq.Documents) == 0 {
|
||||
return nil, fmt.Errorf("documents are required for rerank request")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockRerankRequest{
|
||||
Queries: []BedrockRerankQuery{
|
||||
{
|
||||
Type: bedrockRerankQueryTypeText,
|
||||
TextQuery: BedrockRerankTextRef{
|
||||
Text: bifrostReq.Query,
|
||||
},
|
||||
},
|
||||
},
|
||||
Sources: make([]BedrockRerankSource, len(bifrostReq.Documents)),
|
||||
RerankingConfiguration: BedrockRerankingConfiguration{
|
||||
Type: bedrockRerankConfigurationTypeBedrock,
|
||||
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
|
||||
ModelConfiguration: BedrockRerankModelConfiguration{
|
||||
ModelARN: modelARN,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
bedrockReq.Sources[i] = BedrockRerankSource{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{
|
||||
Text: doc.Text,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostReq.Params == nil {
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
if bifrostReq.Params.TopN != nil {
|
||||
topN := *bifrostReq.Params.TopN
|
||||
if topN < 1 {
|
||||
return nil, fmt.Errorf("top_n must be at least 1")
|
||||
}
|
||||
if topN > len(bifrostReq.Documents) {
|
||||
topN = len(bifrostReq.Documents)
|
||||
}
|
||||
bedrockReq.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults = schemas.Ptr(topN)
|
||||
}
|
||||
|
||||
additionalFields := make(map[string]interface{})
|
||||
if bifrostReq.Params.MaxTokensPerDoc != nil {
|
||||
additionalFields["max_tokens_per_doc"] = *bifrostReq.Params.MaxTokensPerDoc
|
||||
}
|
||||
if bifrostReq.Params.Priority != nil {
|
||||
additionalFields["priority"] = *bifrostReq.Params.Priority
|
||||
}
|
||||
for k, v := range bifrostReq.Params.ExtraParams {
|
||||
additionalFields[k] = v
|
||||
}
|
||||
if len(additionalFields) > 0 {
|
||||
bedrockReq.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields = additionalFields
|
||||
}
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a Bedrock rerank response into Bifrost format.
|
||||
func (response *BedrockRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostRerankResponse{
|
||||
Results: make([]schemas.RerankResult, 0, len(response.Results)),
|
||||
}
|
||||
|
||||
for _, result := range response.Results {
|
||||
rerankResult := schemas.RerankResult{
|
||||
Index: result.Index,
|
||||
RelevanceScore: result.RelevanceScore,
|
||||
}
|
||||
if result.Document != nil && result.Document.TextDocument != nil {
|
||||
rerankResult.Document = &schemas.RerankDocument{
|
||||
Text: result.Document.TextDocument.Text,
|
||||
}
|
||||
}
|
||||
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
|
||||
}
|
||||
|
||||
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
|
||||
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
|
||||
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
|
||||
}
|
||||
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
|
||||
})
|
||||
|
||||
if returnDocuments {
|
||||
for i := range bifrostResponse.Results {
|
||||
resultIndex := bifrostResponse.Results[i].Index
|
||||
if resultIndex >= 0 && resultIndex < len(documents) {
|
||||
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
// ToBifrostRerankRequest converts a Bedrock Agent Runtime rerank request to Bifrost format.
|
||||
func (req *BedrockRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelARN := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.ModelARN
|
||||
provider, model := schemas.ParseModelString(modelARN, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock))
|
||||
|
||||
bifrostReq := &schemas.BifrostRerankRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Params: &schemas.RerankParameters{},
|
||||
}
|
||||
|
||||
// Extract query from the first query entry
|
||||
if len(req.Queries) > 0 {
|
||||
bifrostReq.Query = req.Queries[0].TextQuery.Text
|
||||
}
|
||||
|
||||
// Convert sources to documents
|
||||
for _, source := range req.Sources {
|
||||
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
|
||||
Text: source.InlineDocumentSource.TextDocument.Text,
|
||||
})
|
||||
}
|
||||
|
||||
// Extract TopN from NumberOfResults
|
||||
if req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults != nil {
|
||||
bifrostReq.Params.TopN = req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults
|
||||
}
|
||||
|
||||
// Pass AdditionalModelRequestFields as ExtraParams
|
||||
if fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields; len(fields) > 0 {
|
||||
bifrostReq.Params.ExtraParams = fields
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
230
core/providers/bedrock/rerank_test.go
Normal file
230
core/providers/bedrock/rerank_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToBedrockRerankRequest(t *testing.T) {
|
||||
topN := 10
|
||||
maxTokensPerDoc := 512
|
||||
priority := 3
|
||||
|
||||
req, err := ToBedrockRerankRequest(&schemas.BifrostRerankRequest{
|
||||
Model: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
Query: "capital of france",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital of France."},
|
||||
{Text: "Berlin is the capital of Germany."},
|
||||
},
|
||||
Params: &schemas.RerankParameters{
|
||||
TopN: schemas.Ptr(topN),
|
||||
MaxTokensPerDoc: schemas.Ptr(maxTokensPerDoc),
|
||||
Priority: schemas.Ptr(priority),
|
||||
ExtraParams: map[string]interface{}{
|
||||
"truncate": "END",
|
||||
},
|
||||
},
|
||||
}, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req)
|
||||
|
||||
require.Len(t, req.Queries, 1)
|
||||
assert.Equal(t, "TEXT", req.Queries[0].Type)
|
||||
assert.Equal(t, "capital of france", req.Queries[0].TextQuery.Text)
|
||||
require.Len(t, req.Sources, 2)
|
||||
|
||||
require.NotNil(t, req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults)
|
||||
assert.Equal(t, 2, *req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults, "top_n must be clamped to source count")
|
||||
|
||||
fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields
|
||||
require.NotNil(t, fields)
|
||||
assert.Equal(t, maxTokensPerDoc, fields["max_tokens_per_doc"])
|
||||
assert.Equal(t, priority, fields["priority"])
|
||||
assert.Equal(t, "END", fields["truncate"])
|
||||
}
|
||||
|
||||
func TestBedrockRerankResponseToBifrostRerankResponse(t *testing.T) {
|
||||
response := (&BedrockRerankResponse{
|
||||
Results: []BedrockRerankResult{
|
||||
{
|
||||
Index: 2,
|
||||
RelevanceScore: 0.21,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(nil, false)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 3)
|
||||
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 2, response.Results[2].Index)
|
||||
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "doc-1", response.Results[1].Document.Text)
|
||||
}
|
||||
|
||||
func TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
|
||||
requestDocs := []schemas.RerankDocument{
|
||||
{Text: "request-doc-0"},
|
||||
{Text: "request-doc-1"},
|
||||
{Text: "request-doc-2"},
|
||||
}
|
||||
|
||||
response := (&BedrockRerankResponse{
|
||||
Results: []BedrockRerankResult{
|
||||
{
|
||||
Index: 2,
|
||||
RelevanceScore: 0.21,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(requestDocs, true)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 3)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
require.NotNil(t, response.Results[1].Document)
|
||||
require.NotNil(t, response.Results[2].Document)
|
||||
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 2, response.Results[2].Index)
|
||||
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
|
||||
assert.Equal(t, "request-doc-2", response.Results[2].Document.Text)
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequestToBifrostRerankRequest(t *testing.T) {
|
||||
topN := 3
|
||||
bedrockReq := &BedrockRerankRequest{
|
||||
Queries: []BedrockRerankQuery{
|
||||
{
|
||||
Type: bedrockRerankQueryTypeText,
|
||||
TextQuery: BedrockRerankTextRef{Text: "capital of france"},
|
||||
},
|
||||
},
|
||||
Sources: []BedrockRerankSource{
|
||||
{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{Text: "Paris is the capital of France."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{Text: "Berlin is the capital of Germany."},
|
||||
},
|
||||
},
|
||||
},
|
||||
RerankingConfiguration: BedrockRerankingConfiguration{
|
||||
Type: bedrockRerankConfigurationTypeBedrock,
|
||||
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
|
||||
NumberOfResults: &topN,
|
||||
ModelConfiguration: BedrockRerankModelConfiguration{
|
||||
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
AdditionalModelRequestFields: map[string]interface{}{
|
||||
"truncate": "END",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
result := bedrockReq.ToBifrostRerankRequest(bifrostCtx)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, schemas.Bedrock, result.Provider)
|
||||
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", result.Model)
|
||||
assert.Equal(t, "capital of france", result.Query)
|
||||
require.Len(t, result.Documents, 2)
|
||||
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
|
||||
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
|
||||
require.NotNil(t, result.Params)
|
||||
require.NotNil(t, result.Params.TopN)
|
||||
assert.Equal(t, 3, *result.Params.TopN)
|
||||
require.NotNil(t, result.Params.ExtraParams)
|
||||
assert.Equal(t, "END", result.Params.ExtraParams["truncate"])
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) {
|
||||
var req *BedrockRerankRequest
|
||||
assert.Nil(t, req.ToBifrostRerankRequest(nil))
|
||||
}
|
||||
|
||||
func TestResolveBedrockDeployment(t *testing.T) {
|
||||
key := schemas.Key{
|
||||
Aliases: schemas.KeyAliases{
|
||||
"cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
},
|
||||
}
|
||||
|
||||
deployment := key.Aliases.Resolve("cohere-rerank")
|
||||
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment)
|
||||
assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0"))
|
||||
assert.Equal(t, "", key.Aliases.Resolve(""))
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) {
|
||||
provider := &BedrockProvider{}
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
key := schemas.Key{
|
||||
Aliases: schemas.KeyAliases{
|
||||
"cohere-rerank": "cohere.rerank-v3-5:0",
|
||||
},
|
||||
}
|
||||
|
||||
response, bifrostErr := provider.Rerank(ctx, key, &schemas.BifrostRerankRequest{
|
||||
Model: "cohere-rerank",
|
||||
Query: "capital of france",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital of France."},
|
||||
},
|
||||
})
|
||||
|
||||
require.Nil(t, response)
|
||||
require.NotNil(t, bifrostErr)
|
||||
require.NotNil(t, bifrostErr.Error)
|
||||
assert.Contains(t, bifrostErr.Error.Message, "requires an ARN")
|
||||
}
|
||||
3394
core/providers/bedrock/responses.go
Normal file
3394
core/providers/bedrock/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
130
core/providers/bedrock/s3.go
Normal file
130
core/providers/bedrock/s3.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// uploadToS3 uploads content to an S3 bucket using the provided credentials.
|
||||
func uploadToS3(
|
||||
ctx context.Context,
|
||||
accessKey, secretKey string,
|
||||
sessionToken *string,
|
||||
region string,
|
||||
bucket, key string,
|
||||
content []byte,
|
||||
) *schemas.BifrostError {
|
||||
// Create AWS config with credentials
|
||||
var cfg aws.Config
|
||||
var err error
|
||||
|
||||
if accessKey != "" && secretKey != "" {
|
||||
// Use provided credentials
|
||||
var creds aws.CredentialsProvider
|
||||
if sessionToken != nil && *sessionToken != "" {
|
||||
creds = credentials.NewStaticCredentialsProvider(accessKey, secretKey, *sessionToken)
|
||||
} else {
|
||||
creds = credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
|
||||
}
|
||||
|
||||
cfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(region),
|
||||
config.WithCredentialsProvider(creds),
|
||||
)
|
||||
} else {
|
||||
// Use default credentials chain (IAM role, env vars, etc.)
|
||||
cfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(region))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err)
|
||||
}
|
||||
|
||||
// Create S3 client
|
||||
client := s3.NewFromConfig(cfg)
|
||||
|
||||
// Upload the content
|
||||
_, err = client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(key),
|
||||
Body: bytes.NewReader(content),
|
||||
ContentType: aws.String("application/jsonl"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateBatchInputS3Key builds the object key for a batch input file:
// "bifrost-batch-input/<jobName>-<unix-nanoseconds>.jsonl". The nanosecond
// timestamp is read from the wall clock at call time.
func generateBatchInputS3Key(jobName string) string {
	nanos := time.Now().UnixNano()
	return "bifrost-batch-input/" + jobName + fmt.Sprintf("-%d.jsonl", nanos)
}
|
||||
|
||||
// deriveInputS3URIFromOutput derives an input S3 URI from the output S3 URI.
|
||||
// It uses the same bucket but with a different path for input files.
|
||||
func deriveInputS3URIFromOutput(outputS3URI, inputKey string) string {
|
||||
bucket, _ := parseS3URI(outputS3URI)
|
||||
return fmt.Sprintf("s3://%s/%s", bucket, inputKey)
|
||||
}
|
||||
|
||||
// ConvertBedrockRequestsToJSONL converts batch request items to JSONL format for Bedrock.
|
||||
// Bedrock uses a specific format for batch inference requests.
|
||||
func ConvertBedrockRequestsToJSONL(requests []schemas.BatchRequestItem, modelID *string) ([]byte, error) {
|
||||
// Model ID is required for Bedrock batch JSONL conversion
|
||||
if modelID == nil || *modelID == "" {
|
||||
return nil, fmt.Errorf("modelID is required for Bedrock batch JSONL conversion")
|
||||
}
|
||||
// Initialize the buffer
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Iterate over the requests
|
||||
for _, req := range requests {
|
||||
// Build the Bedrock batch request format
|
||||
bedrockReq := map[string]interface{}{
|
||||
"recordId": req.CustomID,
|
||||
"modelInput": map[string]interface{}{
|
||||
"modelId": *modelID,
|
||||
},
|
||||
}
|
||||
|
||||
// If the request has a body, use it as the model input parameters
|
||||
if req.Body != nil {
|
||||
modelInput := bedrockReq["modelInput"].(map[string]interface{})
|
||||
for k, v := range req.Body {
|
||||
if k != "model" { // Don't override modelId
|
||||
modelInput[k] = v
|
||||
}
|
||||
}
|
||||
} else if req.Params != nil {
|
||||
modelInput := bedrockReq["modelInput"].(map[string]interface{})
|
||||
for k, v := range req.Params {
|
||||
if k != "model" {
|
||||
modelInput[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the request as a JSON line
|
||||
line, err := providerUtils.MarshalSorted(bedrockReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal batch request item %s: %w", req.CustomID, err)
|
||||
}
|
||||
buf.Write(line)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
433
core/providers/bedrock/signer.go
Normal file
433
core/providers/bedrock/signer.go
Normal file
@@ -0,0 +1,433 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/smithy-go/encoding/httpbinding"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Constants used when building an AWS Signature Version 4 request.
const (
	// signingAlgorithm is the algorithm label written into the string to sign
	// and the Authorization header.
	signingAlgorithm = "AWS4-HMAC-SHA256"
	// amzDateKey is the header that carries the request timestamp.
	amzDateKey = "X-Amz-Date"
	// amzSecurityToken is the header that carries a session token, when set.
	amzSecurityToken = "X-Amz-Security-Token"
	// timeFormat is the Go layout for the full timestamp (YYYYMMDDTHHMMSSZ).
	timeFormat = "20060102T150405Z"
	// shortTimeFormat is the Go layout for the date-only stamp (YYYYMMDD).
	shortTimeFormat = "20060102"
)

// ignoredHeaders is the set of lower-cased header names that are left out of
// the canonical (signed) header list.
var ignoredHeaders = map[string]struct{}{
	"authorization":     {},
	"expect":            {},
	"transfer-encoding": {},
	"user-agent":        {},
	"x-amzn-trace-id":   {},
}
|
||||
|
||||
// signingKeyCache caches derived signing keys so the four-step HMAC derivation
// is not repeated for every request. The map is guarded by mu.
type signingKeyCache struct {
	cache map[string]cachedKey
	mu    sync.RWMutex
}

// cachedKey is one cache entry together with the inputs needed to validate it.
type cachedKey struct {
	key       []byte
	date      string // YYYYMMDD format
	accessKey string
	// secretSum is the SHA-256 of the secret the key was derived from. It lets
	// a lookup detect that the secret for an access key has changed without
	// keeping the secret itself in the cache.
	secretSum [sha256.Size]byte
}

// keyCache is the process-wide signing key cache.
var keyCache = &signingKeyCache{
	cache: make(map[string]cachedKey),
}

// hmacSHA256 returns the HMAC-SHA256 of data under key.
func hmacSHA256(key, data []byte) []byte {
	h := hmac.New(sha256.New, key)
	h.Write(data)
	return h.Sum(nil)
}

// deriveSigningKey derives the AWS SigV4 signing key by chaining HMAC-SHA256
// over the date stamp, region, service and the literal "aws4_request",
// starting from "AWS4"+secret.
func deriveSigningKey(secret, dateStamp, region, service string) []byte {
	kDate := hmacSHA256([]byte("AWS4"+secret), []byte(dateStamp))
	kRegion := hmacSHA256(kDate, []byte(region))
	kService := hmacSHA256(kRegion, []byte(service))
	kSigning := hmacSHA256(kService, []byte("aws4_request"))
	return kSigning
}

// getSigningKey returns the signing key for the given credentials and scope,
// deriving and caching it on a miss.
//
// A cached entry is reused only when its access key, date stamp and secret
// digest all match the arguments, so a key derived from a different secret for
// the same access key is never returned. When a new key is stored, entries
// whose date differs from dateStamp are dropped so the cache does not grow
// with every passing day or rotated credential.
//
// The returned slice is shared with the cache and must not be modified.
func getSigningKey(accessKey, secretKey, dateStamp, region, service string) []byte {
	cacheKey := fmt.Sprintf("%s/%s/%s/%s", accessKey, dateStamp, region, service)
	secretSum := sha256.Sum256([]byte(secretKey))

	usable := func(entry cachedKey) bool {
		return entry.accessKey == accessKey && entry.date == dateStamp && entry.secretSum == secretSum
	}

	keyCache.mu.RLock()
	cached, ok := keyCache.cache[cacheKey]
	keyCache.mu.RUnlock()
	if ok && usable(cached) {
		return cached.key
	}

	keyCache.mu.Lock()
	defer keyCache.mu.Unlock()

	// Another goroutine may have stored the key between the two locks.
	if cached, ok := keyCache.cache[cacheKey]; ok && usable(cached) {
		return cached.key
	}

	// Entries scoped to another date cannot match a request signed for
	// dateStamp, so drop them before adding the new one.
	for existingKey, entry := range keyCache.cache {
		if entry.date != dateStamp {
			delete(keyCache.cache, existingKey)
		}
	}

	key := deriveSigningKey(secretKey, dateStamp, region, service)
	keyCache.cache[cacheKey] = cachedKey{
		key:       key,
		date:      dateStamp,
		accessKey: accessKey,
		secretSum: secretSum,
	}

	return key
}
|
||||
|
||||
// stripExcessSpaces trims leading and trailing whitespace from str and
// collapses every run of consecutive spaces (' ') into a single space.
//
// The string is processed byte by byte, so bytes that are not valid UTF-8
// pass through unchanged. This keeps the canonical header value identical to
// the bytes that are actually sent. A value with no run of two spaces is
// returned without copying.
func stripExcessSpaces(str string) string {
	str = strings.TrimSpace(str)
	// Without two adjacent spaces there is nothing to collapse.
	if !strings.Contains(str, "  ") {
		return str
	}

	var result strings.Builder
	result.Grow(len(str))
	prevWasSpace := false

	for i := 0; i < len(str); i++ {
		b := str[i]
		if b == ' ' {
			if prevWasSpace {
				continue
			}
			prevWasSpace = true
		} else {
			prevWasSpace = false
		}
		result.WriteByte(b)
	}

	return result.String()
}
|
||||
|
||||
// percentEncodeRFC3986 percent-encodes s byte by byte. The RFC 3986 unreserved
// characters (A-Z, a-z, 0-9, '-', '_', '.', '~') are copied through unchanged;
// every other byte becomes "%HH" with uppercase hexadecimal digits.
func percentEncodeRFC3986(s string) string {
	var encoded strings.Builder
	encoded.Grow(len(s))

	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case 'A' <= c && c <= 'Z',
			'a' <= c && c <= 'z',
			'0' <= c && c <= '9',
			c == '-', c == '_', c == '.', c == '~':
			encoded.WriteByte(c)
		default:
			encoded.WriteByte('%')
			encoded.WriteByte(uppercaseHex(c >> 4))
			encoded.WriteByte(uppercaseHex(c & 0x0F))
		}
	}

	return encoded.String()
}

// uppercaseHex maps a nibble (0-15) to its uppercase hexadecimal digit.
func uppercaseHex(b byte) byte {
	if b >= 10 {
		return 'A' + (b - 10)
	}
	return '0' + b
}
|
||||
|
||||
// percentDecode replaces every well-formed "%HH" sequence in s with the byte
// it encodes. Unlike form decoding, '+' is left as '+'. A '%' that is not
// followed by two hexadecimal digits is copied through unchanged.
func percentDecode(s string) string {
	// Nothing to decode: return the input without copying.
	if strings.IndexByte(s, '%') < 0 {
		return s
	}

	var decoded strings.Builder
	decoded.Grow(len(s))

	for i := 0; i < len(s); i++ {
		if s[i] == '%' && i+2 < len(s) {
			hi, lo := unhex(s[i+1]), unhex(s[i+2])
			if hi >= 0 && lo >= 0 {
				decoded.WriteByte(byte(hi<<4 | lo))
				i += 2 // skip the two hex digits; the loop skips the '%'
				continue
			}
		}
		decoded.WriteByte(s[i])
	}

	return decoded.String()
}

// unhex returns the value of hexadecimal digit c (either case), or -1 when c
// is not a hexadecimal digit.
func unhex(c byte) int {
	if c >= '0' && c <= '9' {
		return int(c - '0')
	}
	if c >= 'a' && c <= 'f' {
		return int(c - 'a' + 10)
	}
	if c >= 'A' && c <= 'F' {
		return int(c - 'A' + 10)
	}
	return -1
}
|
||||
|
||||
// queryPair represents a query parameter name-value pair
|
||||
type queryPair struct {
|
||||
encodedName string
|
||||
encodedValue string
|
||||
}
|
||||
|
||||
// buildCanonicalQueryString builds a canonical query string per AWS SigV4 spec
|
||||
// using proper RFC 3986 percent-encoding
|
||||
func buildCanonicalQueryString(queryString string) string {
|
||||
if queryString == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Split the raw query string on '&' into pairs
|
||||
rawPairs := strings.Split(queryString, "&")
|
||||
pairs := make([]queryPair, 0, len(rawPairs))
|
||||
|
||||
for _, rawPair := range rawPairs {
|
||||
if rawPair == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split on the first '=' to get name and value
|
||||
var name, value string
|
||||
if idx := strings.IndexByte(rawPair, '='); idx >= 0 {
|
||||
name = rawPair[:idx]
|
||||
value = rawPair[idx+1:]
|
||||
} else {
|
||||
// No '=' means name only, empty value
|
||||
name = rawPair
|
||||
value = ""
|
||||
}
|
||||
|
||||
// Decode percent-encoded sequences first to normalize (handles already-encoded values)
|
||||
// then encode per RFC 3986 to ensure consistent encoding
|
||||
// Note: We use percentDecode instead of url.QueryUnescape because the latter
|
||||
// treats + as space (form encoding), but we need + to encode as %2B
|
||||
decodedName := percentDecode(name)
|
||||
decodedValue := percentDecode(value)
|
||||
|
||||
// Percent-encode name and value per RFC 3986
|
||||
encodedName := percentEncodeRFC3986(decodedName)
|
||||
encodedValue := percentEncodeRFC3986(decodedValue)
|
||||
|
||||
pairs = append(pairs, queryPair{
|
||||
encodedName: encodedName,
|
||||
encodedValue: encodedValue,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort pairs lexicographically by encoded name, then by encoded value
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
if pairs[i].encodedName != pairs[j].encodedName {
|
||||
return pairs[i].encodedName < pairs[j].encodedName
|
||||
}
|
||||
return pairs[i].encodedValue < pairs[j].encodedValue
|
||||
})
|
||||
|
||||
// Join encoded pairs with '&'
|
||||
var result strings.Builder
|
||||
for i, pair := range pairs {
|
||||
if i > 0 {
|
||||
result.WriteByte('&')
|
||||
}
|
||||
result.WriteString(pair.encodedName)
|
||||
result.WriteByte('=')
|
||||
result.WriteString(pair.encodedValue)
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4
|
||||
// This is a native implementation that avoids allocating http.Request
|
||||
func signAWSRequestFastHTTP(
|
||||
ctx context.Context,
|
||||
req *fasthttp.Request,
|
||||
body []byte,
|
||||
accessKey, secretKey string,
|
||||
sessionToken *string,
|
||||
region, service string,
|
||||
) *schemas.BifrostError {
|
||||
// Get AWS credentials if not provided
|
||||
if accessKey == "" && secretKey == "" {
|
||||
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to load aws config", err)
|
||||
}
|
||||
creds, err := cfg.Credentials.Retrieve(ctx)
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
|
||||
}
|
||||
accessKey = creds.AccessKeyID
|
||||
secretKey = creds.SecretAccessKey
|
||||
if creds.SessionToken != "" {
|
||||
st := creds.SessionToken
|
||||
sessionToken = &st
|
||||
}
|
||||
}
|
||||
|
||||
// Get current time
|
||||
now := time.Now().UTC()
|
||||
amzDate := now.Format(timeFormat)
|
||||
dateStamp := now.Format(shortTimeFormat)
|
||||
|
||||
// Parse URI
|
||||
uri := req.URI()
|
||||
host := string(uri.Host())
|
||||
path := string(uri.Path())
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
queryString := string(uri.QueryString())
|
||||
|
||||
// Escape path for canonical URI (Bedrock doesn't disable escaping)
|
||||
canonicalURI := httpbinding.EscapePath(path, false)
|
||||
|
||||
// Calculate payload hash
|
||||
hash := sha256.Sum256(body)
|
||||
payloadHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// Set required headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set(amzDateKey, amzDate)
|
||||
if sessionToken != nil && *sessionToken != "" {
|
||||
req.Header.Set(amzSecurityToken, *sessionToken)
|
||||
}
|
||||
|
||||
// Build canonical headers
|
||||
var headerNames []string
|
||||
headerMap := make(map[string][]string)
|
||||
|
||||
// Always include host
|
||||
headerNames = append(headerNames, "host")
|
||||
headerMap["host"] = []string{host}
|
||||
|
||||
// Include content-length if body is present
|
||||
if cl := req.Header.ContentLength(); cl >= 0 {
|
||||
headerNames = append(headerNames, "content-length")
|
||||
headerMap["content-length"] = []string{strconv.Itoa(cl)}
|
||||
}
|
||||
|
||||
// Collect other headers
|
||||
for key, value := range req.Header.All() {
|
||||
keyStr := strings.ToLower(string(key))
|
||||
|
||||
// Skip ignored headers
|
||||
if _, ignore := ignoredHeaders[keyStr]; ignore {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if already handled
|
||||
if keyStr == "host" || keyStr == "content-length" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := headerMap[keyStr]; !exists {
|
||||
headerNames = append(headerNames, keyStr)
|
||||
}
|
||||
headerMap[keyStr] = append(headerMap[keyStr], string(value))
|
||||
}
|
||||
|
||||
// Sort header names
|
||||
sort.Strings(headerNames)
|
||||
|
||||
// Build canonical headers string
|
||||
var canonicalHeaders strings.Builder
|
||||
for _, name := range headerNames {
|
||||
canonicalHeaders.WriteString(name)
|
||||
canonicalHeaders.WriteRune(':')
|
||||
|
||||
values := headerMap[name]
|
||||
for i, v := range values {
|
||||
cleanedValue := stripExcessSpaces(v)
|
||||
canonicalHeaders.WriteString(cleanedValue)
|
||||
if i < len(values)-1 {
|
||||
canonicalHeaders.WriteRune(',')
|
||||
}
|
||||
}
|
||||
canonicalHeaders.WriteRune('\n')
|
||||
}
|
||||
|
||||
signedHeaders := strings.Join(headerNames, ";")
|
||||
|
||||
// Build canonical query string using RFC 3986 encoding
|
||||
canonicalQueryString := buildCanonicalQueryString(queryString)
|
||||
|
||||
// Build canonical request
|
||||
canonicalRequest := strings.Join([]string{
|
||||
string(req.Header.Method()),
|
||||
canonicalURI,
|
||||
canonicalQueryString,
|
||||
canonicalHeaders.String(),
|
||||
signedHeaders,
|
||||
payloadHash,
|
||||
}, "\n")
|
||||
|
||||
// Build credential scope
|
||||
credentialScope := strings.Join([]string{
|
||||
dateStamp,
|
||||
region,
|
||||
service,
|
||||
"aws4_request",
|
||||
}, "/")
|
||||
|
||||
// Build string to sign
|
||||
canonicalRequestHash := sha256.Sum256([]byte(canonicalRequest))
|
||||
stringToSign := strings.Join([]string{
|
||||
signingAlgorithm,
|
||||
amzDate,
|
||||
credentialScope,
|
||||
hex.EncodeToString(canonicalRequestHash[:]),
|
||||
}, "\n")
|
||||
|
||||
// Calculate signature
|
||||
signingKey := getSigningKey(accessKey, secretKey, dateStamp, region, service)
|
||||
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
|
||||
|
||||
// Build authorization header
|
||||
authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
signingAlgorithm,
|
||||
accessKey,
|
||||
credentialScope,
|
||||
signedHeaders,
|
||||
signature,
|
||||
)
|
||||
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
|
||||
return nil
|
||||
}
|
||||
229
core/providers/bedrock/text.go
Normal file
229
core/providers/bedrock/text.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockTextCompletionRequest converts a Bifrost text completion request to Bedrock format
|
||||
func ToBedrockTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *BedrockTextCompletionRequest {
|
||||
if bifrostReq == nil || (bifrostReq.Input.PromptStr == nil && len(bifrostReq.Input.PromptArray) == 0) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract the raw prompt from bifrostReq
|
||||
prompt := ""
|
||||
if bifrostReq.Input != nil {
|
||||
if bifrostReq.Input.PromptStr != nil {
|
||||
prompt = *bifrostReq.Input.PromptStr
|
||||
} else if len(bifrostReq.Input.PromptArray) > 0 && bifrostReq.Input.PromptArray != nil {
|
||||
prompt = strings.Join(bifrostReq.Input.PromptArray, "\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockTextCompletionRequest{
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
// Apply parameters
|
||||
if bifrostReq.Params != nil {
|
||||
bedrockReq.Temperature = bifrostReq.Params.Temperature
|
||||
bedrockReq.TopP = bifrostReq.Params.TopP
|
||||
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
bedrockReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok {
|
||||
delete(bedrockReq.ExtraParams, "top_k")
|
||||
bedrockReq.TopK = topK
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply model-specific formatting and field naming
|
||||
if strings.Contains(bifrostReq.Model, "anthropic.") || strings.Contains(bifrostReq.Model, "claude") {
|
||||
// For Claude models, wrap the prompt in Anthropic format and use Anthropic field names
|
||||
anthropicReq := anthropic.ToAnthropicTextCompletionRequest(bifrostReq)
|
||||
bedrockReq.Prompt = anthropicReq.Prompt
|
||||
bedrockReq.MaxTokensToSample = &anthropicReq.MaxTokensToSample
|
||||
bedrockReq.StopSequences = anthropicReq.StopSequences
|
||||
} else {
|
||||
// For other models, use standard field names with raw prompt
|
||||
if bifrostReq.Params != nil {
|
||||
bedrockReq.MaxTokens = bifrostReq.Params.MaxTokens
|
||||
bedrockReq.Stop = bifrostReq.Params.Stop
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockReq
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionRequest converts a Bedrock text completion request to Bifrost format
|
||||
func (request *BedrockTextCompletionRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
prompt := request.Prompt
|
||||
// Fallback for Claude 3 Messages API
|
||||
if prompt == "" && len(request.Messages) > 0 {
|
||||
var parts []string
|
||||
for _, msg := range request.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Text != nil {
|
||||
parts = append(parts, *content.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt = strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.ModelID, utils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock))
|
||||
|
||||
bifrostReq := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
},
|
||||
}
|
||||
|
||||
if request.MaxTokens != nil {
|
||||
bifrostReq.Params.MaxTokens = request.MaxTokens
|
||||
} else if request.MaxTokensToSample != nil {
|
||||
bifrostReq.Params.MaxTokens = request.MaxTokensToSample
|
||||
}
|
||||
|
||||
if len(request.Stop) > 0 {
|
||||
bifrostReq.Params.Stop = request.Stop
|
||||
} else if len(request.StopSequences) > 0 {
|
||||
bifrostReq.Params.Stop = request.StopSequences
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionResponse converts a Bedrock Anthropic text response to Bifrost format
|
||||
func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &schemas.BifrostTextCompletionResponse{
|
||||
Object: "text_completion",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: &response.Completion,
|
||||
},
|
||||
FinishReason: &response.StopReason,
|
||||
},
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionResponse converts a Bedrock Mistral text response to Bifrost format
|
||||
func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var choices []schemas.BifrostResponseChoice
|
||||
for i, output := range response.Outputs {
|
||||
choices = append(choices, schemas.BifrostResponseChoice{
|
||||
Index: i,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: &output.Text,
|
||||
},
|
||||
FinishReason: &output.StopReason,
|
||||
})
|
||||
}
|
||||
|
||||
return &schemas.BifrostTextCompletionResponse{
|
||||
Object: "text_completion",
|
||||
Choices: choices,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockTextCompletionResponse converts a BifrostTextCompletionResponse back to Bedrock text completion format
|
||||
// Returns either *BedrockAnthropicTextResponse or *BedrockMistralTextResponse based on the model
|
||||
func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionResponse) interface{} {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determine response format based on resolved model identity.
|
||||
// Use ResolvedModelUsed (actual provider ID) for accurate family detection,
|
||||
// falling back to bifrostResp.Model, then OriginalModelRequested as a last resort.
|
||||
model := bifrostResp.Model
|
||||
if bifrostResp.ExtraFields.ResolvedModelUsed != "" {
|
||||
model = bifrostResp.ExtraFields.ResolvedModelUsed
|
||||
} else if model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "" {
|
||||
model = bifrostResp.ExtraFields.OriginalModelRequested
|
||||
}
|
||||
|
||||
if strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") {
|
||||
// Convert to Anthropic format
|
||||
bedrockResp := &BedrockAnthropicTextResponse{}
|
||||
|
||||
// Convert choices to completion text
|
||||
if len(bifrostResp.Choices) > 0 {
|
||||
choice := bifrostResp.Choices[0] // Anthropic text API typically returns one choice
|
||||
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
||||
bedrockResp.Completion = *choice.TextCompletionResponseChoice.Text
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
bedrockResp.StopReason = *choice.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockResp
|
||||
} else if strings.Contains(model, "mistral.") {
|
||||
// Convert to Mistral format
|
||||
bedrockResp := &BedrockMistralTextResponse{}
|
||||
|
||||
// Convert choices to outputs
|
||||
for _, choice := range bifrostResp.Choices {
|
||||
var output struct {
|
||||
Text string `json:"text"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
}
|
||||
|
||||
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
||||
output.Text = *choice.TextCompletionResponseChoice.Text
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
output.StopReason = *choice.FinishReason
|
||||
}
|
||||
|
||||
bedrockResp.Outputs = append(bedrockResp.Outputs, output)
|
||||
}
|
||||
|
||||
return bedrockResp
|
||||
}
|
||||
|
||||
// Default to Anthropic format if model type cannot be determined
|
||||
bedrockResp := &BedrockAnthropicTextResponse{}
|
||||
if len(bifrostResp.Choices) > 0 {
|
||||
choice := bifrostResp.Choices[0]
|
||||
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
||||
bedrockResp.Completion = *choice.TextCompletionResponseChoice.Text
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
bedrockResp.StopReason = *choice.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockResp
|
||||
}
|
||||
714
core/providers/bedrock/transport_test.go
Normal file
714
core/providers/bedrock/transport_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// redirectTransport is an http.RoundTripper used by tests to point provider
// traffic at a local httptest.Server. Every outgoing request is copied and the
// copy's scheme and host are overwritten with those of target before it is
// handed to the wrapped transport, so provider code needs no changes.
type redirectTransport struct {
	target    *url.URL          // supplies the scheme and host written onto each request
	transport http.RoundTripper // performs the actual round trip
}

// RoundTrip implements http.RoundTripper. The caller's request is not
// modified: a clone carrying the same context is retargeted and sent instead.
func (r *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	out := req.Clone(req.Context())
	host := r.target.Host
	out.URL.Scheme, out.URL.Host = r.target.Scheme, host
	out.Host = host
	return r.transport.RoundTrip(out)
}
|
||||
|
||||
// noopLogger is a no-op schemas.Logger for use in tests.
|
||||
type noopLogger struct{}
|
||||
|
||||
func (noopLogger) Debug(string, ...any) {}
|
||||
func (noopLogger) Info(string, ...any) {}
|
||||
func (noopLogger) Warn(string, ...any) {}
|
||||
func (noopLogger) Error(string, ...any) {}
|
||||
func (noopLogger) Fatal(string, ...any) {}
|
||||
func (noopLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
// newTestProviderWithServer returns a BedrockProvider whose HTTP client is
|
||||
// redirected to the given httptest.Server.
|
||||
func newTestProviderWithServer(t *testing.T, ts *httptest.Server) *BedrockProvider {
|
||||
t.Helper()
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 5,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
provider, err := NewBedrockProvider(config, noopLogger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
targetURL, err := url.Parse(ts.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
redirect := &redirectTransport{
|
||||
target: targetURL,
|
||||
transport: ts.Client().Transport,
|
||||
}
|
||||
provider.client = &http.Client{
|
||||
Transport: redirect,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
// Streaming paths use streamingClient (no Timeout); redirect it to the
|
||||
// test server too, otherwise Bedrock streaming tests would hit the real
|
||||
// AWS endpoint.
|
||||
provider.streamingClient = &http.Client{Transport: redirect}
|
||||
return provider
|
||||
}
|
||||
|
||||
// testBedrockKey returns a minimal Key with a bearer value so makeStreamingRequest
|
||||
// skips IAM signing and proceeds to the HTTP call.
|
||||
func testBedrockKey() schemas.Key {
|
||||
region := schemas.NewEnvVar("us-east-1")
|
||||
return schemas.Key{
|
||||
Value: *schemas.NewEnvVar("test-api-key"),
|
||||
BedrockKeyConfig: &schemas.BedrockKeyConfig{
|
||||
Region: region,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// testBedrockCtx returns a BifrostContext suitable for unit tests.
|
||||
func testBedrockCtx() *schemas.BifrostContext {
|
||||
return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
}
|
||||
|
||||
// noopPostHookRunner is a PostHookRunner that passes through results unchanged.
|
||||
func noopPostHookRunner(_ *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// testChatRequest returns a minimal BifrostChatRequest for streaming tests.
|
||||
func testChatRequest() *schemas.BifrostChatRequest {
|
||||
content := "hello"
|
||||
return &schemas.BifrostChatRequest{
|
||||
Model: "anthropic.claude-sonnet-4-5",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: &content},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeStreamingRequest_StaleConnection_IsRetryable verifies that when the
|
||||
// HTTP server closes the connection before sending a response (simulating a
|
||||
// stale HTTP/2 connection), makeStreamingRequest returns a BifrostError with
|
||||
// IsBifrostError:false so the retry gate in executeRequestWithRetries retries.
|
||||
func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) {
|
||||
// Server that immediately closes the connection without sending anything.
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "hijack not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, _, _ := hj.Hijack()
|
||||
conn.Close() // close without writing any response
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
ctx := testBedrockCtx()
|
||||
key := testBedrockKey()
|
||||
|
||||
_, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream")
|
||||
|
||||
require.NotNil(t, bifrostErr, "expected error when server closes connection")
|
||||
assert.False(t, bifrostErr.IsBifrostError,
|
||||
"stale-connection error must be IsBifrostError:false so the retry gate can retry it")
|
||||
require.NotNil(t, bifrostErr.Error)
|
||||
// Either ErrProviderNetworkError (net.OpError) or ErrProviderDoRequest (EOF/connection-reset)
|
||||
// are both retryable — the key invariant is IsBifrostError:false.
|
||||
assert.Contains(t, []string{schemas.ErrProviderNetworkError, schemas.ErrProviderDoRequest}, bifrostErr.Error.Message,
|
||||
"stale-connection error must use a retryable error message")
|
||||
}
|
||||
|
||||
// TestChatCompletionStream_StaleConnection_ChunkIsRetryable verifies that when
|
||||
// the server returns HTTP 200 but closes the body immediately (simulating a
|
||||
// stale connection mid-stream before any EventStream data arrives), the first
|
||||
// chunk received from the stream channel carries a BifrostError with
|
||||
// IsBifrostError:false so that CheckFirstStreamChunkForError + the retry gate
|
||||
// can transparently retry the request.
|
||||
func TestChatCompletionStream_StaleConnection_ChunkIsRetryable(t *testing.T) {
|
||||
// Server: returns 200 with the correct content-type but closes body immediately.
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
conn, _, _ := hj.Hijack()
|
||||
conn.Close() // close without any EventStream bytes
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
ctx := testBedrockCtx()
|
||||
key := testBedrockKey()
|
||||
|
||||
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
|
||||
|
||||
if bifrostErr != nil {
|
||||
// Error surfaced synchronously (e.g. connection refused before HTTP 200).
|
||||
assert.False(t, bifrostErr.IsBifrostError,
|
||||
"pre-stream network error must be IsBifrostError:false")
|
||||
return
|
||||
}
|
||||
|
||||
// Error surfaced as the first stream chunk.
|
||||
require.NotNil(t, streamChan)
|
||||
chunk, ok := <-streamChan
|
||||
require.True(t, ok, "channel must not be empty")
|
||||
require.NotNil(t, chunk)
|
||||
require.NotNil(t, chunk.BifrostError, "expected an error chunk from the stream")
|
||||
|
||||
assert.False(t, chunk.BifrostError.IsBifrostError,
|
||||
"stream transport error must be IsBifrostError:false so the retry gate can retry it")
|
||||
require.NotNil(t, chunk.BifrostError.Error)
|
||||
assert.Equal(t, schemas.ErrProviderNetworkError, chunk.BifrostError.Error.Message,
|
||||
"stream transport error must use ErrProviderNetworkError message")
|
||||
|
||||
// Drain any remaining chunks.
|
||||
for range streamChan {
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatCompletionStream_NetOpError_ChunkIsRetryable verifies the specific
|
||||
// "use of closed network connection" *net.OpError scenario from issue #2424:
|
||||
// a successful HTTP connection that is then closed server-side produces a
|
||||
// *net.OpError during EventStream decoding, which must arrive as a retryable
|
||||
// IsBifrostError:false chunk.
|
||||
func TestChatCompletionStream_NetOpError_ChunkIsRetryable(t *testing.T) {
|
||||
// Server: returns 200 + correct headers, writes a truncated EventStream
|
||||
// prelude (not a valid frame), then forcibly resets the TCP connection —
|
||||
// producing a *net.OpError on the client's read side.
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
// Write a partial EventStream frame header (3 bytes, not a valid frame).
|
||||
_, _ = w.Write([]byte{0x00, 0x00, 0x00})
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
conn, _, _ := hj.Hijack()
|
||||
// RST instead of FIN — guarantees a *net.OpError on the client read.
|
||||
if tc, ok := conn.(*net.TCPConn); ok {
|
||||
_ = tc.SetLinger(0)
|
||||
}
|
||||
conn.Close()
|
||||
}))
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
ctx := testBedrockCtx()
|
||||
key := testBedrockKey()
|
||||
|
||||
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
|
||||
if bifrostErr != nil {
|
||||
assert.False(t, bifrostErr.IsBifrostError,
|
||||
"pre-stream network error must be IsBifrostError:false")
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, streamChan)
|
||||
|
||||
// Collect chunks until we find an error chunk (may not be the very first
|
||||
// if the OS buffers the partial write, but it must appear before close).
|
||||
var errChunk *schemas.BifrostStreamChunk
|
||||
for chunk := range streamChan {
|
||||
if chunk != nil && chunk.BifrostError != nil {
|
||||
errChunk = chunk
|
||||
break
|
||||
}
|
||||
}
|
||||
// Drain remaining.
|
||||
for range streamChan {
|
||||
}
|
||||
|
||||
require.NotNil(t, errChunk, "expected an error chunk from the stream")
|
||||
assert.False(t, errChunk.BifrostError.IsBifrostError,
|
||||
"net.OpError during EventStream decoding must be IsBifrostError:false so the retry gate can retry it")
|
||||
require.NotNil(t, errChunk.BifrostError.Error)
|
||||
assert.Equal(t, schemas.ErrProviderNetworkError, errChunk.BifrostError.Error.Message,
|
||||
"net.OpError during EventStream decoding must use ErrProviderNetworkError message")
|
||||
}
|
||||
|
||||
// writeEventStreamException encodes a well-formed AWS EventStream exception
|
||||
// frame with the given exception type and message into w.
|
||||
// The frame format is: prelude (total_len + headers_len + CRC) + headers + payload + message_CRC.
|
||||
// We use the AWS SDK's eventstream.Encoder so the binary framing is correct.
|
||||
func writeEventStreamException(t *testing.T, w io.Writer, excType, msg string) {
|
||||
t.Helper()
|
||||
enc := eventstream.NewEncoder()
|
||||
payload, err := json.Marshal(map[string]string{"message": msg})
|
||||
require.NoError(t, err, "failed to marshal exception payload")
|
||||
headers := eventstream.Headers{
|
||||
{Name: ":message-type", Value: eventstream.StringValue("exception")},
|
||||
{Name: ":exception-type", Value: eventstream.StringValue(excType)},
|
||||
{Name: ":content-type", Value: eventstream.StringValue("application/json")},
|
||||
}
|
||||
err = enc.Encode(w, eventstream.Message{Headers: headers, Payload: payload})
|
||||
require.NoError(t, err, "failed to encode EventStream exception frame")
|
||||
}
|
||||
|
||||
// TestChatCompletionStream_RetryableException_ChunkIsRetryable verifies that
|
||||
// when AWS Bedrock sends a retryable exception (serviceUnavailableException,
|
||||
// throttlingException, etc.) through the EventStream, the resulting error chunk
|
||||
// has IsBifrostError:false and the correct HTTP StatusCode so that the retry
|
||||
// gate in executeRequestWithRetries can retry the request.
|
||||
func TestChatCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
excType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"serviceUnavailableException", 503},
|
||||
{"throttlingException", 429},
|
||||
{"modelNotReadyException", 503},
|
||||
{"internalServerException", 500},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.excType, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
ctx := testBedrockCtx()
|
||||
key := testBedrockKey()
|
||||
|
||||
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
|
||||
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk")
|
||||
|
||||
require.NotNil(t, streamChan)
|
||||
|
||||
var errChunk *schemas.BifrostStreamChunk
|
||||
for chunk := range streamChan {
|
||||
if chunk != nil && chunk.BifrostError != nil {
|
||||
errChunk = chunk
|
||||
break
|
||||
}
|
||||
}
|
||||
for range streamChan {
|
||||
}
|
||||
|
||||
require.NotNil(t, errChunk, "expected error chunk for %s", tc.excType)
|
||||
assert.False(t, errChunk.BifrostError.IsBifrostError,
|
||||
"%s must be IsBifrostError:false so retry gate can retry it", tc.excType)
|
||||
require.NotNil(t, errChunk.BifrostError.StatusCode,
|
||||
"%s must carry a StatusCode for the retryableStatusCodes gate", tc.excType)
|
||||
assert.Equal(t, tc.expectedStatus, *errChunk.BifrostError.StatusCode,
|
||||
"%s must map to HTTP %d", tc.excType, tc.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatCompletionStream_NonRetryableException_IsTerminal verifies that
|
||||
// non-retryable exception types (e.g. validationException, accessDeniedException)
|
||||
// continue to use ProcessAndSendError (IsBifrostError:true) and are NOT retried.
|
||||
func TestChatCompletionStream_NonRetryableException_IsTerminal(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, "validationException", "input validation failed")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
ctx := testBedrockCtx()
|
||||
key := testBedrockKey()
|
||||
|
||||
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
|
||||
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk")
|
||||
|
||||
require.NotNil(t, streamChan)
|
||||
|
||||
var errChunk *schemas.BifrostStreamChunk
|
||||
for chunk := range streamChan {
|
||||
if chunk != nil && chunk.BifrostError != nil {
|
||||
errChunk = chunk
|
||||
break
|
||||
}
|
||||
}
|
||||
for range streamChan {
|
||||
}
|
||||
|
||||
require.NotNil(t, errChunk, "expected error chunk for validationException")
|
||||
assert.True(t, errChunk.BifrostError.IsBifrostError,
|
||||
"non-retryable validationException must remain IsBifrostError:true")
|
||||
}
|
||||
|
||||
// testTextCompletionRequest returns a minimal BifrostTextCompletionRequest for streaming tests.
|
||||
func testTextCompletionRequest() *schemas.BifrostTextCompletionRequest {
|
||||
prompt := "hello"
|
||||
return &schemas.BifrostTextCompletionRequest{
|
||||
Model: "anthropic.claude-sonnet-4-5",
|
||||
Input: &schemas.TextCompletionInput{PromptStr: &prompt},
|
||||
}
|
||||
}
|
||||
|
||||
// testResponsesRequest returns a minimal BifrostResponsesRequest for streaming tests.
|
||||
func testResponsesRequest() *schemas.BifrostResponsesRequest {
|
||||
msgType := schemas.ResponsesMessageType("message")
|
||||
roleUser := schemas.ResponsesMessageRoleType("user")
|
||||
content := "hello"
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Model: "anthropic.claude-sonnet-4-5",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: &msgType,
|
||||
Role: &roleUser,
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: &content},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// assertRetryableExceptionChunk is the shared assertion helper for all three
|
||||
// streaming-method retryable-exception tests.
|
||||
func assertRetryableExceptionChunk(t *testing.T, streamChan chan *schemas.BifrostStreamChunk, bifrostErr *schemas.BifrostError, excType string, expectedStatus int) {
|
||||
t.Helper()
|
||||
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk, not a pre-stream error")
|
||||
require.NotNil(t, streamChan)
|
||||
|
||||
var errChunk *schemas.BifrostStreamChunk
|
||||
for chunk := range streamChan {
|
||||
if chunk != nil && chunk.BifrostError != nil {
|
||||
errChunk = chunk
|
||||
break
|
||||
}
|
||||
}
|
||||
for range streamChan {
|
||||
}
|
||||
|
||||
require.NotNil(t, errChunk, "expected error chunk for %s", excType)
|
||||
assert.False(t, errChunk.BifrostError.IsBifrostError,
|
||||
"%s must be IsBifrostError:false so retry gate can retry it", excType)
|
||||
require.NotNil(t, errChunk.BifrostError.StatusCode,
|
||||
"%s must carry a StatusCode for the retryableStatusCodes gate", excType)
|
||||
assert.Equal(t, expectedStatus, *errChunk.BifrostError.StatusCode,
|
||||
"%s must map to HTTP %d", excType, expectedStatus)
|
||||
}
|
||||
|
||||
// TestTextCompletionStream_RetryableException_ChunkIsRetryable mirrors the
|
||||
// ChatCompletionStream test for the TextCompletionStream path, which has
|
||||
// slightly different payload-parsing logic (extra BedrockError JSON unmarshal).
|
||||
func TestTextCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
excType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"serviceUnavailableException", 503},
|
||||
{"throttlingException", 429},
|
||||
{"modelNotReadyException", 503},
|
||||
{"internalServerException", 500},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.excType, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
streamChan, bifrostErr := provider.TextCompletionStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testTextCompletionRequest())
|
||||
assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponsesStream_RetryableException_ChunkIsRetryable mirrors the
|
||||
// ChatCompletionStream test for the ResponsesStream path.
|
||||
func TestResponsesStream_RetryableException_ChunkIsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
excType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"serviceUnavailableException", 503},
|
||||
{"throttlingException", 429},
|
||||
{"modelNotReadyException", 503},
|
||||
{"internalServerException", 500},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.excType, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
streamChan, bifrostErr := provider.ResponsesStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testResponsesRequest())
|
||||
assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestCACert(t *testing.T) string {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "testca"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
return string(certPEM)
|
||||
}
|
||||
|
||||
func TestBedrockTransportHTTP2Config(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
MaxConnsPerHost: 5000,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, provider)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok, "transport should be *http.Transport")
|
||||
|
||||
assert.Equal(t, 5000, transport.MaxConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportCustomMaxConns(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
MaxConnsPerHost: 50,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, 50, transport.MaxConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
|
||||
}
|
||||
|
||||
func TestBedrockTransportDefaultMaxConns(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
// MaxConnsPerHost left as 0 — should default to 5000
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
assert.Equal(t, schemas.DefaultMaxConnsPerHost, config.NetworkConfig.MaxConnsPerHost)
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, schemas.DefaultMaxConnsPerHost, transport.MaxConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
|
||||
}
|
||||
|
||||
func TestBedrockTransportTLSInsecureSkipVerify(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
InsecureSkipVerify: true,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, transport.TLSClientConfig)
|
||||
assert.True(t, transport.TLSClientConfig.InsecureSkipVerify)
|
||||
assert.Equal(t, uint16(tls.VersionTLS12), transport.TLSClientConfig.MinVersion)
|
||||
// ForceAttemptHTTP2 should still be true even with custom TLS config
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportTLSCACert(t *testing.T) {
|
||||
testCACert := generateTestCACert(t)
|
||||
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
CACertPEM: testCACert,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, transport.TLSClientConfig)
|
||||
assert.NotNil(t, transport.TLSClientConfig.RootCAs)
|
||||
assert.Equal(t, uint16(tls.VersionTLS12), transport.TLSClientConfig.MinVersion)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportDefaultTLS(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
// No TLS settings — should use system defaults
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
// No custom TLS config should be set
|
||||
assert.Nil(t, transport.TLSClientConfig)
|
||||
// EnforceHTTP2 not set — ForceAttemptHTTP2 should be false
|
||||
assert.False(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportEnforceHTTP2(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
// TLSNextProto should NOT be set when HTTP/2 is enforced, allowing ALPN negotiation
|
||||
assert.Nil(t, transport.TLSNextProto)
|
||||
}
|
||||
|
||||
func TestBedrockTransportEnforceHTTP2Disabled(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
EnforceHTTP2: false,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
assert.False(t, transport.ForceAttemptHTTP2)
|
||||
// TLSNextProto must be set to empty map to truly disable HTTP/2 ALPN negotiation
|
||||
assert.NotNil(t, transport.TLSNextProto)
|
||||
assert.Empty(t, transport.TLSNextProto)
|
||||
}
|
||||
1130
core/providers/bedrock/types.go
Normal file
1130
core/providers/bedrock/types.go
Normal file
File diff suppressed because it is too large
Load Diff
1843
core/providers/bedrock/utils.go
Normal file
1843
core/providers/bedrock/utils.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user