first commit
This commit is contained in:
637
framework/vectorstore/weaviate.go
Normal file
637
framework/vectorstore/weaviate.go
Normal file
@@ -0,0 +1,637 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
|
||||
"github.com/weaviate/weaviate/entities/models"
|
||||
)
|
||||
|
||||
// Default values for Weaviate vector index configuration
|
||||
const (
|
||||
// Default class names (Weaviate prefers PascalCase)
|
||||
DefaultClassName = "BifrostStore"
|
||||
)
|
||||
|
||||
// WeaviateConfig represents the configuration for the Weaviate vector store.
|
||||
type WeaviateConfig struct {
|
||||
// Connection settings
|
||||
Scheme string `json:"scheme"` // "http" or "https" - REQUIRED
|
||||
Host *schemas.EnvVar `json:"host"` // "localhost:8080" - REQUIRED
|
||||
GrpcConfig *WeaviateGrpcConfig `json:"grpc_config,omitempty"` // grpc config for weaviate (optional)
|
||||
|
||||
// Authentication settings (optional)
|
||||
APIKey *schemas.EnvVar `json:"api_key,omitempty"` // API key for authentication
|
||||
Headers map[string]string `json:"headers,omitempty"` // Additional headers
|
||||
|
||||
// Connection settings
|
||||
Timeout time.Duration `json:"timeout,omitempty"` // Request timeout (optional)
|
||||
}
|
||||
|
||||
type WeaviateGrpcConfig struct {
|
||||
// Host is the host of the weaviate server (host:port).
|
||||
// If host is without a port number then the 80 port for insecured and 443 port for secured connections will be used.
|
||||
Host *schemas.EnvVar `json:"host"`
|
||||
// Secured is a boolean flag indicating if the connection is secured
|
||||
Secured bool `json:"secured"`
|
||||
}
|
||||
|
||||
// WeaviateStore represents the Weaviate vector store.
|
||||
type WeaviateStore struct {
|
||||
client *weaviate.Client
|
||||
config *WeaviateConfig
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// Ping checks if the Weaviate server is reachable.
|
||||
func (s *WeaviateStore) Ping(ctx context.Context) error {
|
||||
_, err := s.client.Misc().MetaGetter().Do(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Add stores a new object (with or without embedding)
|
||||
func (s *WeaviateStore) Add(ctx context.Context, className string, id string, embedding []float32, metadata map[string]interface{}) error {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
obj := &models.Object{
|
||||
Class: className,
|
||||
Properties: metadata,
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(embedding) > 0 {
|
||||
_, err = s.client.Data().Creator().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
WithProperties(obj.Properties).
|
||||
WithVector(embedding).
|
||||
Do(ctx)
|
||||
} else {
|
||||
_, err = s.client.Data().Creator().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
WithProperties(obj.Properties).
|
||||
Do(ctx)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetChunk returns the "metadata" for a single key
|
||||
func (s *WeaviateStore) GetChunk(ctx context.Context, className string, id string) (SearchResult, error) {
|
||||
obj, err := s.client.Data().ObjectsGetter().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return SearchResult{}, err
|
||||
}
|
||||
if len(obj) == 0 {
|
||||
return SearchResult{}, fmt.Errorf("not found: %s", id)
|
||||
}
|
||||
|
||||
props, ok := obj[0].Properties.(map[string]interface{})
|
||||
if !ok {
|
||||
return SearchResult{}, fmt.Errorf("invalid properties")
|
||||
}
|
||||
|
||||
return SearchResult{
|
||||
ID: id,
|
||||
Score: nil,
|
||||
Properties: props,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChunks returns multiple objects by ID
|
||||
func (s *WeaviateStore) GetChunks(ctx context.Context, className string, ids []string) ([]SearchResult, error) {
|
||||
out := make([]SearchResult, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
obj, err := s.client.Data().ObjectsGetter().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(obj) > 0 {
|
||||
props, ok := obj[0].Properties.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid properties")
|
||||
}
|
||||
out = append(out, SearchResult{
|
||||
ID: id,
|
||||
Score: nil,
|
||||
Properties: props,
|
||||
})
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetAll with filtering + pagination
|
||||
func (s *WeaviateStore) GetAll(ctx context.Context, className string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
|
||||
where := buildWeaviateFilter(queries)
|
||||
|
||||
fields := []graphql.Field{
|
||||
{Name: "_additional", Fields: []graphql.Field{
|
||||
{Name: "id"},
|
||||
}},
|
||||
}
|
||||
for _, field := range selectFields {
|
||||
fields = append(fields, graphql.Field{Name: field})
|
||||
}
|
||||
|
||||
search := s.client.GraphQL().Get().
|
||||
WithClassName(className).
|
||||
WithLimit(int(limit)).
|
||||
WithFields(fields...)
|
||||
|
||||
if where != nil {
|
||||
search = search.WithWhere(where)
|
||||
}
|
||||
if cursor != nil {
|
||||
search = search.WithAfter(*cursor)
|
||||
}
|
||||
|
||||
resp, err := search.Do(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Check for GraphQL errors
|
||||
if len(resp.Errors) > 0 {
|
||||
var errorMsgs []string
|
||||
for _, err := range resp.Errors {
|
||||
errorMsgs = append(errorMsgs, err.Message)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("graphql errors: %v", errorMsgs)
|
||||
}
|
||||
|
||||
data, ok := resp.Data["Get"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
|
||||
}
|
||||
|
||||
objsRaw, exists := data[className]
|
||||
if !exists {
|
||||
// No results for this class - this is normal, not an error
|
||||
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
objs, ok := objsRaw.([]interface{})
|
||||
if !ok {
|
||||
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(objs))
|
||||
var nextCursor *string
|
||||
for _, o := range objs {
|
||||
obj, ok := o.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert to SearchResult format for consistency
|
||||
searchResult := SearchResult{
|
||||
Properties: obj,
|
||||
}
|
||||
|
||||
if additional, ok := obj["_additional"].(map[string]interface{}); ok {
|
||||
if id, ok := additional["id"].(string); ok {
|
||||
searchResult.ID = id
|
||||
nextCursor = &id
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, searchResult)
|
||||
}
|
||||
|
||||
return results, nextCursor, nil
|
||||
}
|
||||
|
||||
// GetNearest with explicit filters only
|
||||
func (s *WeaviateStore) GetNearest(
|
||||
ctx context.Context,
|
||||
className string,
|
||||
vector []float32,
|
||||
queries []Query,
|
||||
selectFields []string,
|
||||
threshold float64,
|
||||
limit int64,
|
||||
) ([]SearchResult, error) {
|
||||
where := buildWeaviateFilter(queries)
|
||||
|
||||
fields := []graphql.Field{
|
||||
{Name: "_additional", Fields: []graphql.Field{
|
||||
{Name: "id"},
|
||||
{Name: "certainty"},
|
||||
}},
|
||||
}
|
||||
|
||||
for _, field := range selectFields {
|
||||
fields = append(fields, graphql.Field{Name: field})
|
||||
}
|
||||
|
||||
nearVector := s.client.GraphQL().NearVectorArgBuilder().
|
||||
WithVector(vector).
|
||||
WithCertainty(float32(threshold))
|
||||
|
||||
search := s.client.GraphQL().Get().
|
||||
WithClassName(className).
|
||||
WithNearVector(nearVector).
|
||||
WithLimit(int(limit)).
|
||||
WithFields(fields...)
|
||||
|
||||
if where != nil {
|
||||
search = search.WithWhere(where)
|
||||
}
|
||||
|
||||
resp, err := search.Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for GraphQL errors
|
||||
if len(resp.Errors) > 0 {
|
||||
var errorMsgs []string
|
||||
for _, err := range resp.Errors {
|
||||
errorMsgs = append(errorMsgs, err.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("graphql errors: %v", errorMsgs)
|
||||
}
|
||||
|
||||
data, ok := resp.Data["Get"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
|
||||
}
|
||||
|
||||
objsRaw, exists := data[className]
|
||||
if !exists {
|
||||
// No results for this class - this is normal, not an error
|
||||
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
objs, ok := objsRaw.([]interface{})
|
||||
if !ok {
|
||||
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(objs))
|
||||
for _, o := range objs {
|
||||
obj, ok := o.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
additional, ok := obj["_additional"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Safely extract ID
|
||||
idRaw, exists := additional["id"]
|
||||
if !exists || idRaw == nil {
|
||||
continue
|
||||
}
|
||||
id, ok := idRaw.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Safely extract certainty/score with default value
|
||||
var score float64
|
||||
if certaintyRaw, exists := additional["certainty"]; exists && certaintyRaw != nil {
|
||||
switch v := certaintyRaw.(type) {
|
||||
case float64:
|
||||
score = v
|
||||
case float32:
|
||||
score = float64(v)
|
||||
case int:
|
||||
score = float64(v)
|
||||
case int64:
|
||||
score = float64(v)
|
||||
default:
|
||||
score = 0.0 // Default score if type conversion fails
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, SearchResult{
|
||||
ID: id,
|
||||
Score: &score,
|
||||
Properties: obj,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Delete removes multiple objects by ID
|
||||
func (s *WeaviateStore) Delete(ctx context.Context, className string, id string) error {
|
||||
return s.client.Data().Deleter().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
Do(ctx)
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) DeleteAll(ctx context.Context, className string, queries []Query) ([]DeleteResult, error) {
|
||||
// Check if class exists first to avoid 500 errors from Weaviate
|
||||
exists, err := s.client.Schema().ClassExistenceChecker().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check class existence: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return []DeleteResult{}, nil // Class doesn't exist, nothing to delete
|
||||
}
|
||||
|
||||
where := buildWeaviateFilter(queries)
|
||||
|
||||
res, err := s.client.Batch().ObjectsBatchDeleter().
|
||||
WithClassName(className).
|
||||
WithWhere(where).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// NOTE: Weaviate is returning an empty array for Results.Objects, even on successful deletes.
|
||||
results := make([]DeleteResult, 0, len(res.Results.Objects))
|
||||
|
||||
for _, obj := range res.Results.Objects {
|
||||
result := DeleteResult{
|
||||
ID: obj.ID.String(),
|
||||
}
|
||||
|
||||
if obj.Status != nil {
|
||||
switch *obj.Status {
|
||||
case "SUCCESS":
|
||||
result.Status = DeleteStatusSuccess
|
||||
case "FAILED":
|
||||
result.Status = DeleteStatusError
|
||||
|
||||
if obj.Errors != nil {
|
||||
var errorMsgs []string
|
||||
for _, err := range obj.Errors.Error {
|
||||
errorMsgs = append(errorMsgs, err.Message)
|
||||
}
|
||||
|
||||
result.Error = strings.Join(errorMsgs, ", ")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) Close(ctx context.Context, className string) error {
|
||||
// nothing to close
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequiresVectors returns true because Weaviate's HNSW index
|
||||
// requires vectors for proper object indexing and retrieval.
|
||||
func (s *WeaviateStore) RequiresVectors() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// newWeaviateStore creates a new Weaviate vector store.
|
||||
func newWeaviateStore(ctx context.Context, config *WeaviateConfig, logger schemas.Logger) (*WeaviateStore, error) {
|
||||
// Validate required config
|
||||
if config.Scheme == "" || (config.Host == nil || config.Host.GetValue() == "") {
|
||||
return nil, fmt.Errorf("weaviate scheme and host are required")
|
||||
}
|
||||
// Build client configuration
|
||||
cfg := weaviate.Config{
|
||||
Scheme: config.Scheme,
|
||||
Host: config.Host.GetValue(),
|
||||
}
|
||||
|
||||
// Add authentication if provided
|
||||
if config.APIKey != nil && config.APIKey.GetValue() != "" {
|
||||
cfg.AuthConfig = auth.ApiKey{Value: config.APIKey.GetValue()}
|
||||
}
|
||||
|
||||
// Add grpc config if provided
|
||||
if config.GrpcConfig != nil {
|
||||
if config.GrpcConfig.Host == nil || config.GrpcConfig.Host.GetValue() == "" {
|
||||
return nil, fmt.Errorf("weaviate grpc host is required")
|
||||
}
|
||||
cfg.GrpcConfig = &grpc.Config{
|
||||
Host: config.GrpcConfig.Host.GetValue(),
|
||||
Secured: config.GrpcConfig.Secured,
|
||||
}
|
||||
}
|
||||
|
||||
// Add custom headers if provided
|
||||
if len(config.Headers) > 0 {
|
||||
cfg.Headers = config.Headers
|
||||
}
|
||||
|
||||
// Create client
|
||||
client, err := weaviate.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create weaviate client: %w", err)
|
||||
}
|
||||
|
||||
// Test connection with meta endpoint
|
||||
testCtx := ctx
|
||||
if config.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
testCtx, cancel = context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
_, err = client.Misc().MetaGetter().Do(testCtx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to weaviate: %w", err)
|
||||
}
|
||||
|
||||
store := &WeaviateStore{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) CreateNamespace(ctx context.Context, className string, dimension int, properties map[string]VectorStoreProperties) error {
|
||||
// Check if class exists
|
||||
exists, err := s.client.Schema().ClassExistenceChecker().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check class existence: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
return nil // Schema already exists
|
||||
}
|
||||
|
||||
// Create properties
|
||||
weaviateProperties := []*models.Property{}
|
||||
for name, prop := range properties {
|
||||
var dataType []string
|
||||
switch prop.DataType {
|
||||
case VectorStorePropertyTypeString:
|
||||
dataType = []string{"string"}
|
||||
case VectorStorePropertyTypeInteger:
|
||||
dataType = []string{"int"}
|
||||
case VectorStorePropertyTypeBoolean:
|
||||
dataType = []string{"boolean"}
|
||||
case VectorStorePropertyTypeStringArray:
|
||||
dataType = []string{"string[]"}
|
||||
}
|
||||
|
||||
weaviateProperties = append(weaviateProperties, &models.Property{
|
||||
Name: name,
|
||||
DataType: dataType,
|
||||
Description: prop.Description,
|
||||
})
|
||||
}
|
||||
|
||||
// Create class schema with all fields we need
|
||||
classSchema := &models.Class{
|
||||
Class: className,
|
||||
Properties: weaviateProperties,
|
||||
VectorIndexType: "hnsw",
|
||||
Vectorizer: "none", // We provide our own vectors
|
||||
}
|
||||
|
||||
if dimension > 0 {
|
||||
classSchema.VectorIndexConfig = map[string]interface{}{
|
||||
"vectorDimensions": dimension,
|
||||
}
|
||||
}
|
||||
|
||||
err = s.client.Schema().ClassCreator().
|
||||
WithClass(classSchema).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create class schema: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) DeleteNamespace(ctx context.Context, className string) error {
|
||||
exists, err := s.client.Schema().ClassExistenceChecker().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check class existence: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil // Schema already does not exist
|
||||
} else {
|
||||
return s.client.Schema().ClassDeleter().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// buildWeaviateFilter converts []Query → Weaviate WhereFilter
|
||||
func buildWeaviateFilter(queries []Query) *filters.WhereBuilder {
|
||||
if len(queries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var operands []*filters.WhereBuilder
|
||||
for _, q := range queries {
|
||||
// Convert string operator to filters operator
|
||||
operator := convertOperator(q.Operator)
|
||||
|
||||
fieldPath := strings.Split(q.Field, ".")
|
||||
|
||||
whereClause := filters.Where().
|
||||
WithPath(fieldPath).
|
||||
WithOperator(operator)
|
||||
|
||||
// Special handling for IsNull and IsNotNull
|
||||
switch q.Operator {
|
||||
case QueryOperatorIsNull:
|
||||
whereClause = whereClause.WithValueBoolean(true)
|
||||
case QueryOperatorIsNotNull:
|
||||
whereClause = whereClause.WithValueBoolean(false)
|
||||
default:
|
||||
// Set value based on type
|
||||
switch v := q.Value.(type) {
|
||||
case string:
|
||||
whereClause = whereClause.WithValueString(v)
|
||||
case int:
|
||||
whereClause = whereClause.WithValueInt(int64(v))
|
||||
case int64:
|
||||
whereClause = whereClause.WithValueInt(v)
|
||||
case float32:
|
||||
whereClause = whereClause.WithValueNumber(float64(v))
|
||||
case float64:
|
||||
whereClause = whereClause.WithValueNumber(v)
|
||||
case bool:
|
||||
whereClause = whereClause.WithValueBoolean(v)
|
||||
default:
|
||||
// Fallback to string conversion
|
||||
whereClause = whereClause.WithValueString(fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
|
||||
operands = append(operands, whereClause)
|
||||
}
|
||||
|
||||
if len(operands) == 1 {
|
||||
return operands[0]
|
||||
}
|
||||
|
||||
// Create AND filter for multiple operands
|
||||
return filters.Where().
|
||||
WithOperator(filters.And).
|
||||
WithOperands(operands)
|
||||
}
|
||||
|
||||
// convertOperator converts string operator to filters operator
|
||||
func convertOperator(op QueryOperator) filters.WhereOperator {
|
||||
switch op {
|
||||
case QueryOperatorEqual:
|
||||
return filters.Equal
|
||||
case QueryOperatorNotEqual:
|
||||
return filters.NotEqual
|
||||
case QueryOperatorLessThan:
|
||||
return filters.LessThan
|
||||
case QueryOperatorLessThanOrEqual:
|
||||
return filters.LessThanEqual
|
||||
case QueryOperatorGreaterThan:
|
||||
return filters.GreaterThan
|
||||
case QueryOperatorGreaterThanOrEqual:
|
||||
return filters.GreaterThanEqual
|
||||
case QueryOperatorLike:
|
||||
return filters.Like
|
||||
case QueryOperatorContainsAny:
|
||||
return filters.ContainsAny
|
||||
case QueryOperatorContainsAll:
|
||||
return filters.ContainsAll
|
||||
case QueryOperatorIsNull:
|
||||
return filters.IsNull
|
||||
case QueryOperatorIsNotNull: // IsNotNull is not supported by Weaviate, so we use IsNull and negate it.
|
||||
return filters.IsNull
|
||||
default:
|
||||
// Default to Equal if unknown
|
||||
return filters.Equal
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user