Files
bifrost/framework/vectorstore/weaviate.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

638 lines
16 KiB
Go

package vectorstore
import (
"context"
"fmt"
"strings"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/weaviate/weaviate-go-client/v5/weaviate"
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
"github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
"github.com/weaviate/weaviate/entities/models"
)
// Default values for the Weaviate vector store.
const (
	// DefaultClassName is the default class name (Weaviate prefers PascalCase).
	DefaultClassName = "BifrostStore"
)
// WeaviateConfig represents the configuration for the Weaviate vector store.
// Scheme and Host are mandatory; every other field is optional.
type WeaviateConfig struct {
	// Connection settings
	Scheme     string              `json:"scheme"`                // "http" or "https" - REQUIRED
	Host       *schemas.EnvVar     `json:"host"`                  // "localhost:8080" - REQUIRED
	GrpcConfig *WeaviateGrpcConfig `json:"grpc_config,omitempty"` // gRPC config for Weaviate (optional)
	// Authentication settings (optional)
	APIKey  *schemas.EnvVar   `json:"api_key,omitempty"` // API key for authentication
	Headers map[string]string `json:"headers,omitempty"` // Additional headers
	// Timeout settings
	// NOTE: newWeaviateStore only applies Timeout to its startup connectivity check.
	Timeout time.Duration `json:"timeout,omitempty"` // Request timeout (optional)
}
// WeaviateGrpcConfig holds the optional gRPC connection settings for Weaviate.
type WeaviateGrpcConfig struct {
	// Host is the host of the Weaviate server (host:port).
	// If the host has no port number, port 80 is used for insecure connections
	// and port 443 for secured connections.
	Host *schemas.EnvVar `json:"host"`
	// Secured is a boolean flag indicating whether the connection is secured.
	Secured bool `json:"secured"`
}
// WeaviateStore represents the Weaviate vector store.
type WeaviateStore struct {
	client *weaviate.Client // Weaviate client used for every request
	config *WeaviateConfig  // configuration the store was created with
	logger schemas.Logger   // logger used for debug output
}
// Ping verifies that the Weaviate server answers a request to its meta
// endpoint. It returns the client error unchanged when the request fails.
func (s *WeaviateStore) Ping(ctx context.Context) error {
	if _, err := s.client.Misc().MetaGetter().Do(ctx); err != nil {
		return err
	}
	return nil
}
// Add creates an object with the given ID and properties in className.
// The embedding is attached as the object's vector only when it is non-empty.
// A blank (empty or whitespace-only) id is rejected before any request is made.
func (s *WeaviateStore) Add(ctx context.Context, className string, id string, embedding []float32, metadata map[string]interface{}) error {
	if strings.TrimSpace(id) == "" {
		return fmt.Errorf("id is required")
	}

	creator := s.client.Data().Creator().
		WithClassName(className).
		WithID(id).
		WithProperties(metadata)
	if len(embedding) > 0 {
		creator = creator.WithVector(embedding)
	}

	_, err := creator.Do(ctx)
	return err
}
// GetChunk fetches a single object by ID and returns its properties.
// The returned SearchResult carries the requested id and no score.
// It fails when the lookup errors, when no object comes back, or when the
// object's properties are not a map.
func (s *WeaviateStore) GetChunk(ctx context.Context, className string, id string) (SearchResult, error) {
	getter := s.client.Data().ObjectsGetter().
		WithClassName(className).
		WithID(id)

	objects, err := getter.Do(ctx)
	switch {
	case err != nil:
		return SearchResult{}, err
	case len(objects) == 0:
		return SearchResult{}, fmt.Errorf("not found: %s", id)
	}

	properties, isMap := objects[0].Properties.(map[string]interface{})
	if !isMap {
		return SearchResult{}, fmt.Errorf("invalid properties")
	}

	found := SearchResult{ID: id, Properties: properties}
	found.Score = nil
	return found, nil
}
// GetChunks fetches several objects by ID, one request per ID, in the order
// the IDs are given. IDs that resolve to no object are skipped; any lookup
// error or a non-map properties payload aborts the whole call.
func (s *WeaviateStore) GetChunks(ctx context.Context, className string, ids []string) ([]SearchResult, error) {
	results := make([]SearchResult, 0, len(ids))

	for i := range ids {
		objectID := ids[i]

		fetched, err := s.client.Data().ObjectsGetter().
			WithClassName(className).
			WithID(objectID).
			Do(ctx)
		if err != nil {
			return nil, err
		}
		if len(fetched) == 0 {
			// Nothing stored under this ID; leave it out of the result.
			continue
		}

		properties, isMap := fetched[0].Properties.(map[string]interface{})
		if !isMap {
			return nil, fmt.Errorf("invalid properties")
		}

		results = append(results, SearchResult{
			ID:         objectID,
			Score:      nil,
			Properties: properties,
		})
	}

	return results, nil
}
// GetAll lists objects of className with optional filtering and cursor-based
// pagination.
//
// queries are turned into a where filter (skipped when empty), selectFields
// are requested alongside the object ID, cursor (when non-nil) is passed as
// the "after" position, and limit caps the page size.
//
// It returns the page of results and the ID of the last result that carried a
// string ID, for use as the next cursor. When the response has no entry for
// className, or that entry is not a list, it returns nil results, a nil
// cursor and a nil error.
func (s *WeaviateStore) GetAll(ctx context.Context, className string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
	// Always request the object ID; caller-selected properties follow it.
	requested := make([]graphql.Field, 0, len(selectFields)+1)
	requested = append(requested, graphql.Field{
		Name:   "_additional",
		Fields: []graphql.Field{{Name: "id"}},
	})
	for _, name := range selectFields {
		requested = append(requested, graphql.Field{Name: name})
	}

	query := s.client.GraphQL().Get().
		WithClassName(className).
		WithLimit(int(limit)).
		WithFields(requested...)
	if filter := buildWeaviateFilter(queries); filter != nil {
		query = query.WithWhere(filter)
	}
	if cursor != nil {
		query = query.WithAfter(*cursor)
	}

	resp, err := query.Do(ctx)
	if err != nil {
		return nil, nil, err
	}

	// GraphQL-level errors arrive in the response body, not as a Go error.
	if len(resp.Errors) > 0 {
		messages := make([]string, len(resp.Errors))
		for i, gqlErr := range resp.Errors {
			messages[i] = gqlErr.Message
		}
		return nil, nil, fmt.Errorf("graphql errors: %v", messages)
	}

	getSection, isMap := resp.Data["Get"].(map[string]interface{})
	if !isMap {
		return nil, nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
	}

	rawItems, present := getSection[className]
	if !present {
		// No results for this class - this is normal, not an error.
		s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, getSection))
		return nil, nil, nil
	}

	items, isList := rawItems.([]interface{})
	if !isList {
		s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, rawItems))
		return nil, nil, nil
	}

	var nextCursor *string
	results := make([]SearchResult, 0, len(items))
	for _, rawItem := range items {
		entry, isEntry := rawItem.(map[string]interface{})
		if !isEntry {
			continue
		}

		// The full entry (including "_additional") is exposed as properties.
		result := SearchResult{Properties: entry}
		if extra, hasExtra := entry["_additional"].(map[string]interface{}); hasExtra {
			if objectID, isString := extra["id"].(string); isString {
				result.ID = objectID
				nextCursor = &objectID
			}
		}
		results = append(results, result)
	}

	return results, nextCursor, nil
}
// GetNearest runs a near-vector search against className.
//
// vector is the query embedding, threshold is passed to Weaviate as the
// minimum certainty, queries become an optional where filter, selectFields
// are requested alongside the object ID and certainty, and limit caps the
// number of results.
//
// Each result's Score is the certainty reported for the object, or 0 when it
// is missing or not numeric. Entries without a string ID are dropped. When
// the response has no entry for className, or that entry is not a list, it
// returns nil results and a nil error.
func (s *WeaviateStore) GetNearest(
	ctx context.Context,
	className string,
	vector []float32,
	queries []Query,
	selectFields []string,
	threshold float64,
	limit int64,
) ([]SearchResult, error) {
	// Always request ID and certainty; caller-selected properties follow.
	requested := make([]graphql.Field, 0, len(selectFields)+1)
	requested = append(requested, graphql.Field{
		Name: "_additional",
		Fields: []graphql.Field{
			{Name: "id"},
			{Name: "certainty"},
		},
	})
	for _, name := range selectFields {
		requested = append(requested, graphql.Field{Name: name})
	}

	near := s.client.GraphQL().NearVectorArgBuilder().
		WithVector(vector).
		WithCertainty(float32(threshold))

	query := s.client.GraphQL().Get().
		WithClassName(className).
		WithNearVector(near).
		WithLimit(int(limit)).
		WithFields(requested...)
	if filter := buildWeaviateFilter(queries); filter != nil {
		query = query.WithWhere(filter)
	}

	resp, err := query.Do(ctx)
	if err != nil {
		return nil, err
	}

	// GraphQL-level errors arrive in the response body, not as a Go error.
	if len(resp.Errors) > 0 {
		messages := make([]string, len(resp.Errors))
		for i, gqlErr := range resp.Errors {
			messages[i] = gqlErr.Message
		}
		return nil, fmt.Errorf("graphql errors: %v", messages)
	}

	getSection, isMap := resp.Data["Get"].(map[string]interface{})
	if !isMap {
		return nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
	}

	rawItems, present := getSection[className]
	if !present {
		// No results for this class - this is normal, not an error.
		s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, getSection))
		return nil, nil
	}

	items, isList := rawItems.([]interface{})
	if !isList {
		s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, rawItems))
		return nil, nil
	}

	results := make([]SearchResult, 0, len(items))
	for _, rawItem := range items {
		entry, isEntry := rawItem.(map[string]interface{})
		if !isEntry {
			continue
		}
		extra, hasExtra := entry["_additional"].(map[string]interface{})
		if !hasExtra {
			continue
		}

		// A missing, nil or non-string ID fails the assertion and is skipped.
		objectID, isString := extra["id"].(string)
		if !isString {
			continue
		}

		// certainty is declared per iteration so every result owns its pointer.
		// Missing, nil or non-numeric values leave it at zero.
		var certainty float64
		switch value := extra["certainty"].(type) {
		case float64:
			certainty = value
		case float32:
			certainty = float64(value)
		case int:
			certainty = float64(value)
		case int64:
			certainty = float64(value)
		}

		results = append(results, SearchResult{
			ID:         objectID,
			Score:      &certainty,
			Properties: entry,
		})
	}

	return results, nil
}
// Delete removes a single object, identified by id, from className.
// The client error is returned unchanged.
func (s *WeaviateStore) Delete(ctx context.Context, className string, id string) error {
	deleter := s.client.Data().Deleter().
		WithClassName(className).
		WithID(id)
	return deleter.Do(ctx)
}
// DeleteAll batch-deletes every object in className that matches queries.
//
// If the class does not exist there is nothing to delete and an empty result
// is returned. At least one query is required: the batch delete needs a where
// filter, so an empty query list is rejected with an error before any delete
// request is sent.
//
// The returned slice holds one DeleteResult per object reported by Weaviate.
// It may be empty even after a successful delete, because Weaviate does not
// always list the deleted objects.
func (s *WeaviateStore) DeleteAll(ctx context.Context, className string, queries []Query) ([]DeleteResult, error) {
	// Check if class exists first to avoid 500 errors from Weaviate
	exists, err := s.client.Schema().ClassExistenceChecker().
		WithClassName(className).
		Do(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to check class existence: %w", err)
	}
	if !exists {
		return []DeleteResult{}, nil // Class doesn't exist, nothing to delete
	}

	where := buildWeaviateFilter(queries)
	if where == nil {
		// Never hand a nil filter to the batch deleter; fail with a clear message.
		return nil, fmt.Errorf("at least one query is required to delete objects from class %q", className)
	}

	res, err := s.client.Batch().ObjectsBatchDeleter().
		WithClassName(className).
		WithWhere(where).
		Do(ctx)
	if err != nil {
		return nil, err
	}
	// Results is a pointer in the response model; guard before dereferencing.
	if res == nil || res.Results == nil {
		return []DeleteResult{}, nil
	}

	// NOTE: Weaviate is returning an empty array for Results.Objects, even on successful deletes.
	results := make([]DeleteResult, 0, len(res.Results.Objects))
	for _, obj := range res.Results.Objects {
		if obj == nil {
			continue
		}
		result := DeleteResult{
			ID: obj.ID.String(),
		}
		if obj.Status != nil {
			switch *obj.Status {
			case "SUCCESS":
				result.Status = DeleteStatusSuccess
			case "FAILED":
				result.Status = DeleteStatusError
				if obj.Errors != nil {
					errorMsgs := make([]string, 0, len(obj.Errors.Error))
					for _, item := range obj.Errors.Error {
						if item != nil {
							errorMsgs = append(errorMsgs, item.Message)
						}
					}
					result.Error = strings.Join(errorMsgs, ", ")
				}
			}
		}
		results = append(results, result)
	}
	return results, nil
}
// Close is a no-op for the Weaviate store and always returns nil.
// Both ctx and className are ignored.
func (s *WeaviateStore) Close(ctx context.Context, className string) error {
	// nothing to close
	return nil
}
// RequiresVectors always returns true because Weaviate's HNSW index
// requires vectors for proper object indexing and retrieval.
func (s *WeaviateStore) RequiresVectors() bool {
	return true
}
// newWeaviateStore creates a new Weaviate vector store.
//
// It validates config (a non-nil config with scheme and host is required, and
// a gRPC host is required when GrpcConfig is set), builds the client with the
// optional API key, gRPC settings and headers, and then checks connectivity
// by calling the meta endpoint. When config.Timeout is positive, it bounds
// that connectivity check only.
//
// It returns an error if validation fails, the client cannot be created, or
// the server cannot be reached.
func newWeaviateStore(ctx context.Context, config *WeaviateConfig, logger schemas.Logger) (*WeaviateStore, error) {
	// Validate required config
	if config == nil {
		return nil, fmt.Errorf("weaviate config is required")
	}
	if config.Scheme == "" || (config.Host == nil || config.Host.GetValue() == "") {
		return nil, fmt.Errorf("weaviate scheme and host are required")
	}

	// Build client configuration
	cfg := weaviate.Config{
		Scheme: config.Scheme,
		Host:   config.Host.GetValue(),
	}

	// Add authentication if provided
	if config.APIKey != nil && config.APIKey.GetValue() != "" {
		cfg.AuthConfig = auth.ApiKey{Value: config.APIKey.GetValue()}
	}

	// Add grpc config if provided
	if config.GrpcConfig != nil {
		if config.GrpcConfig.Host == nil || config.GrpcConfig.Host.GetValue() == "" {
			return nil, fmt.Errorf("weaviate grpc host is required")
		}
		cfg.GrpcConfig = &grpc.Config{
			Host:    config.GrpcConfig.Host.GetValue(),
			Secured: config.GrpcConfig.Secured,
		}
	}

	// Add custom headers if provided
	if len(config.Headers) > 0 {
		cfg.Headers = config.Headers
	}

	// Create client
	client, err := weaviate.NewClient(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create weaviate client: %w", err)
	}

	// Test connection with meta endpoint
	testCtx := ctx
	if config.Timeout > 0 {
		var cancel context.CancelFunc
		testCtx, cancel = context.WithTimeout(ctx, config.Timeout)
		defer cancel()
	}
	if _, err = client.Misc().MetaGetter().Do(testCtx); err != nil {
		return nil, fmt.Errorf("failed to connect to weaviate: %w", err)
	}

	return &WeaviateStore{
		client: client,
		config: config,
		logger: logger,
	}, nil
}
// CreateNamespace creates the Weaviate class className if it does not exist.
//
// properties maps property names to their definitions; each one is translated
// to the matching Weaviate data type. A property with an unsupported data
// type is rejected with an error and no class is created. The class uses an
// HNSW vector index with no vectorizer, because vectors are supplied by the
// caller. When dimension is positive it is recorded in the vector index
// configuration.
//
// It is a no-op when the class already exists.
func (s *WeaviateStore) CreateNamespace(ctx context.Context, className string, dimension int, properties map[string]VectorStoreProperties) error {
	// Check if class exists
	exists, err := s.client.Schema().ClassExistenceChecker().
		WithClassName(className).
		Do(ctx)
	if err != nil {
		return fmt.Errorf("failed to check class existence: %w", err)
	}
	if exists {
		return nil // Schema already exists
	}

	// Create properties
	weaviateProperties := make([]*models.Property, 0, len(properties))
	for name, prop := range properties {
		var dataType []string
		switch prop.DataType {
		case VectorStorePropertyTypeString:
			dataType = []string{"string"}
		case VectorStorePropertyTypeInteger:
			dataType = []string{"int"}
		case VectorStorePropertyTypeBoolean:
			dataType = []string{"boolean"}
		case VectorStorePropertyTypeStringArray:
			dataType = []string{"string[]"}
		default:
			// Fail fast instead of sending a property with an empty data type.
			return fmt.Errorf("unsupported data type %v for property %q", prop.DataType, name)
		}
		weaviateProperties = append(weaviateProperties, &models.Property{
			Name:        name,
			DataType:    dataType,
			Description: prop.Description,
		})
	}

	// Create class schema with all fields we need
	classSchema := &models.Class{
		Class:           className,
		Properties:      weaviateProperties,
		VectorIndexType: "hnsw",
		Vectorizer:      "none", // We provide our own vectors
	}
	if dimension > 0 {
		classSchema.VectorIndexConfig = map[string]interface{}{
			"vectorDimensions": dimension,
		}
	}

	err = s.client.Schema().ClassCreator().
		WithClass(classSchema).
		Do(ctx)
	if err != nil {
		return fmt.Errorf("failed to create class schema: %w", err)
	}
	return nil
}
// DeleteNamespace removes the Weaviate class className.
// It is a no-op when the class does not exist. A failed existence check is
// returned wrapped; a failed delete is returned unchanged.
func (s *WeaviateStore) DeleteNamespace(ctx context.Context, className string) error {
	schema := s.client.Schema()

	found, err := schema.ClassExistenceChecker().
		WithClassName(className).
		Do(ctx)
	if err != nil {
		return fmt.Errorf("failed to check class existence: %w", err)
	}
	if found {
		return schema.ClassDeleter().
			WithClassName(className).
			Do(ctx)
	}

	// Schema already does not exist
	return nil
}
// buildWeaviateFilter converts []Query → Weaviate WhereFilter.
//
// It returns nil when queries is empty, the single clause when there is
// exactly one query, and an AND of all clauses otherwise. A query's Field is
// split on "." to form the property path.
//
// IsNull and IsNotNull both use the IsNull operator, with a boolean value of
// true and false respectively; the query's own Value is ignored for them.
func buildWeaviateFilter(queries []Query) *filters.WhereBuilder {
	if len(queries) == 0 {
		return nil
	}

	operands := make([]*filters.WhereBuilder, 0, len(queries))
	for _, q := range queries {
		whereClause := filters.Where().
			WithPath(strings.Split(q.Field, ".")).
			WithOperator(convertOperator(q.Operator))

		// Special handling for IsNull and IsNotNull
		switch q.Operator {
		case QueryOperatorIsNull:
			whereClause = whereClause.WithValueBoolean(true)
		case QueryOperatorIsNotNull:
			whereClause = whereClause.WithValueBoolean(false)
		default:
			whereClause = withWeaviateFilterValue(whereClause, q.Value)
		}
		operands = append(operands, whereClause)
	}

	if len(operands) == 1 {
		return operands[0]
	}

	// Create AND filter for multiple operands
	return filters.Where().
		WithOperator(filters.And).
		WithOperands(operands)
}

// withWeaviateFilterValue sets value on clause using the setter that matches
// its Go type. Slice values are passed as multi-value operands (as needed by
// ContainsAny / ContainsAll) instead of being formatted into a single string.
// Elements of an untyped []interface{} and any other unsupported type fall
// back to their "%v" string form.
func withWeaviateFilterValue(clause *filters.WhereBuilder, value interface{}) *filters.WhereBuilder {
	switch v := value.(type) {
	case string:
		return clause.WithValueString(v)
	case []string:
		return clause.WithValueString(v...)
	case int:
		return clause.WithValueInt(int64(v))
	case int64:
		return clause.WithValueInt(v)
	case []int:
		ints := make([]int64, len(v))
		for i, n := range v {
			ints[i] = int64(n)
		}
		return clause.WithValueInt(ints...)
	case []int64:
		return clause.WithValueInt(v...)
	case float32:
		return clause.WithValueNumber(float64(v))
	case float64:
		return clause.WithValueNumber(v)
	case []float64:
		return clause.WithValueNumber(v...)
	case bool:
		return clause.WithValueBoolean(v)
	case []bool:
		return clause.WithValueBoolean(v...)
	case []interface{}:
		strs := make([]string, len(v))
		for i, elem := range v {
			strs[i] = fmt.Sprintf("%v", elem)
		}
		return clause.WithValueString(strs...)
	default:
		// Fallback to string conversion
		return clause.WithValueString(fmt.Sprintf("%v", v))
	}
}
// convertOperator maps a QueryOperator to the corresponding Weaviate where
// operator. Unknown operators are treated as Equal.
func convertOperator(op QueryOperator) filters.WhereOperator {
	// Equal doubles as the fallback for operators not listed below.
	converted := filters.Equal

	switch op {
	case QueryOperatorNotEqual:
		converted = filters.NotEqual
	case QueryOperatorLessThan:
		converted = filters.LessThan
	case QueryOperatorLessThanOrEqual:
		converted = filters.LessThanEqual
	case QueryOperatorGreaterThan:
		converted = filters.GreaterThan
	case QueryOperatorGreaterThanOrEqual:
		converted = filters.GreaterThanEqual
	case QueryOperatorLike:
		converted = filters.Like
	case QueryOperatorContainsAny:
		converted = filters.ContainsAny
	case QueryOperatorContainsAll:
		converted = filters.ContainsAll
	case QueryOperatorIsNull, QueryOperatorIsNotNull:
		// IsNotNull is not supported by Weaviate, so it also maps to IsNull;
		// the negation is expressed through the filter's boolean value.
		converted = filters.IsNull
	}

	return converted
}