first commit
This commit is contained in:
190
core/providers/cohere/embedding.go
Normal file
190
core/providers/cohere/embedding.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format
|
||||
func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddingInput := bifrostReq.Input
|
||||
cohereReq := &CohereEmbeddingRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
texts := []string{}
|
||||
if embeddingInput.Text != nil {
|
||||
texts = append(texts, *embeddingInput.Text)
|
||||
} else {
|
||||
texts = embeddingInput.Texts
|
||||
}
|
||||
|
||||
// Convert texts from Bifrost format
|
||||
if len(texts) > 0 {
|
||||
cohereReq.Texts = texts
|
||||
}
|
||||
|
||||
// Set default input type if not specified in extra params
|
||||
cohereReq.InputType = "search_document" // Default value
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.OutputDimension = bifrostReq.Params.Dimensions
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if maxTokens, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["max_tokens"]); ok {
|
||||
delete(cohereReq.ExtraParams, "max_tokens")
|
||||
cohereReq.MaxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle extra params
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
// Input type
|
||||
if inputType, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["input_type"]); ok {
|
||||
delete(cohereReq.ExtraParams, "input_type")
|
||||
cohereReq.InputType = inputType
|
||||
}
|
||||
|
||||
// Embedding types
|
||||
if embeddingTypes, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["embedding_types"]); ok {
|
||||
if len(embeddingTypes) > 0 {
|
||||
delete(cohereReq.ExtraParams, "embedding_types")
|
||||
cohereReq.EmbeddingTypes = embeddingTypes
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate
|
||||
if truncate, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["truncate"]); ok {
|
||||
delete(cohereReq.ExtraParams, "truncate")
|
||||
cohereReq.Truncate = truncate
|
||||
}
|
||||
}
|
||||
|
||||
return cohereReq
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingRequest converts a Cohere embedding request to Bifrost format
|
||||
func (req *CohereEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.EmbeddingInput{},
|
||||
Params: &schemas.EmbeddingParameters{},
|
||||
}
|
||||
|
||||
// Convert texts
|
||||
if len(req.Texts) > 0 {
|
||||
if len(req.Texts) == 1 {
|
||||
bifrostReq.Input.Text = &req.Texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = req.Texts
|
||||
}
|
||||
}
|
||||
|
||||
// Convert parameters
|
||||
if req.OutputDimension != nil {
|
||||
bifrostReq.Params.Dimensions = req.OutputDimension
|
||||
}
|
||||
|
||||
// Convert extra params
|
||||
extraParams := make(map[string]interface{})
|
||||
if req.InputType != "" {
|
||||
extraParams["input_type"] = req.InputType
|
||||
}
|
||||
if req.EmbeddingTypes != nil {
|
||||
extraParams["embedding_types"] = req.EmbeddingTypes
|
||||
}
|
||||
if req.Truncate != nil {
|
||||
extraParams["truncate"] = *req.Truncate
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
extraParams["max_tokens"] = *req.MaxTokens
|
||||
}
|
||||
if len(extraParams) > 0 {
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Cohere embedding response to Bifrost format
|
||||
func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
// Convert embeddings data
|
||||
if response.Embeddings != nil {
|
||||
var bifrostEmbeddings []schemas.EmbeddingData
|
||||
|
||||
// Handle different embedding types - prioritize float embeddings
|
||||
if response.Embeddings.Float != nil {
|
||||
for i, embedding := range response.Embeddings.Float {
|
||||
bifrostEmbedding := schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: embedding,
|
||||
},
|
||||
}
|
||||
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
|
||||
}
|
||||
} else if response.Embeddings.Base64 != nil {
|
||||
// Handle base64 embeddings as strings
|
||||
for i, embedding := range response.Embeddings.Base64 {
|
||||
bifrostEmbedding := schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingStr: &embedding,
|
||||
},
|
||||
}
|
||||
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
|
||||
}
|
||||
}
|
||||
// Note: Int8, Uint8, Binary, Ubinary types would need special handling
|
||||
// depending on how Bifrost wants to represent them
|
||||
|
||||
bifrostResponse.Data = bifrostEmbeddings
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Meta != nil {
|
||||
if response.Meta.Tokens != nil {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
|
||||
if response.Meta.Tokens.InputTokens != nil {
|
||||
bifrostResponse.Usage.PromptTokens = int(*response.Meta.Tokens.InputTokens)
|
||||
}
|
||||
if response.Meta.Tokens.OutputTokens != nil {
|
||||
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.Tokens.OutputTokens)
|
||||
}
|
||||
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
|
||||
} else if response.Meta.BilledUnits != nil {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
|
||||
if response.Meta.BilledUnits.InputTokens != nil {
|
||||
bifrostResponse.Usage.PromptTokens = int(*response.Meta.BilledUnits.InputTokens)
|
||||
}
|
||||
if response.Meta.BilledUnits.OutputTokens != nil {
|
||||
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.BilledUnits.OutputTokens)
|
||||
}
|
||||
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
Reference in New Issue
Block a user