638 lines
16 KiB
Go
638 lines
16 KiB
Go
package vectorstore
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate"
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
|
|
"github.com/weaviate/weaviate/entities/models"
|
|
)
|
|
|
|
// Default values for Weaviate vector index configuration
|
|
const (
|
|
// Default class names (Weaviate prefers PascalCase)
|
|
DefaultClassName = "BifrostStore"
|
|
)
|
|
|
|
// WeaviateConfig represents the configuration for the Weaviate vector store.
|
|
type WeaviateConfig struct {
|
|
// Connection settings
|
|
Scheme string `json:"scheme"` // "http" or "https" - REQUIRED
|
|
Host *schemas.EnvVar `json:"host"` // "localhost:8080" - REQUIRED
|
|
GrpcConfig *WeaviateGrpcConfig `json:"grpc_config,omitempty"` // grpc config for weaviate (optional)
|
|
|
|
// Authentication settings (optional)
|
|
APIKey *schemas.EnvVar `json:"api_key,omitempty"` // API key for authentication
|
|
Headers map[string]string `json:"headers,omitempty"` // Additional headers
|
|
|
|
// Connection settings
|
|
Timeout time.Duration `json:"timeout,omitempty"` // Request timeout (optional)
|
|
}
|
|
|
|
type WeaviateGrpcConfig struct {
|
|
// Host is the host of the weaviate server (host:port).
|
|
// If host is without a port number then the 80 port for insecured and 443 port for secured connections will be used.
|
|
Host *schemas.EnvVar `json:"host"`
|
|
// Secured is a boolean flag indicating if the connection is secured
|
|
Secured bool `json:"secured"`
|
|
}
|
|
|
|
// WeaviateStore represents the Weaviate vector store.
|
|
type WeaviateStore struct {
|
|
client *weaviate.Client
|
|
config *WeaviateConfig
|
|
logger schemas.Logger
|
|
}
|
|
|
|
// Ping checks if the Weaviate server is reachable.
|
|
func (s *WeaviateStore) Ping(ctx context.Context) error {
|
|
_, err := s.client.Misc().MetaGetter().Do(ctx)
|
|
return err
|
|
}
|
|
|
|
// Add stores a new object (with or without embedding)
|
|
func (s *WeaviateStore) Add(ctx context.Context, className string, id string, embedding []float32, metadata map[string]interface{}) error {
|
|
if strings.TrimSpace(id) == "" {
|
|
return fmt.Errorf("id is required")
|
|
}
|
|
|
|
obj := &models.Object{
|
|
Class: className,
|
|
Properties: metadata,
|
|
}
|
|
|
|
var err error
|
|
if len(embedding) > 0 {
|
|
_, err = s.client.Data().Creator().
|
|
WithClassName(className).
|
|
WithID(id).
|
|
WithProperties(obj.Properties).
|
|
WithVector(embedding).
|
|
Do(ctx)
|
|
} else {
|
|
_, err = s.client.Data().Creator().
|
|
WithClassName(className).
|
|
WithID(id).
|
|
WithProperties(obj.Properties).
|
|
Do(ctx)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// GetChunk returns the "metadata" for a single key
|
|
func (s *WeaviateStore) GetChunk(ctx context.Context, className string, id string) (SearchResult, error) {
|
|
obj, err := s.client.Data().ObjectsGetter().
|
|
WithClassName(className).
|
|
WithID(id).
|
|
Do(ctx)
|
|
if err != nil {
|
|
return SearchResult{}, err
|
|
}
|
|
if len(obj) == 0 {
|
|
return SearchResult{}, fmt.Errorf("not found: %s", id)
|
|
}
|
|
|
|
props, ok := obj[0].Properties.(map[string]interface{})
|
|
if !ok {
|
|
return SearchResult{}, fmt.Errorf("invalid properties")
|
|
}
|
|
|
|
return SearchResult{
|
|
ID: id,
|
|
Score: nil,
|
|
Properties: props,
|
|
}, nil
|
|
}
|
|
|
|
// GetChunks returns multiple objects by ID
|
|
func (s *WeaviateStore) GetChunks(ctx context.Context, className string, ids []string) ([]SearchResult, error) {
|
|
out := make([]SearchResult, 0, len(ids))
|
|
for _, id := range ids {
|
|
obj, err := s.client.Data().ObjectsGetter().
|
|
WithClassName(className).
|
|
WithID(id).
|
|
Do(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(obj) > 0 {
|
|
props, ok := obj[0].Properties.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid properties")
|
|
}
|
|
out = append(out, SearchResult{
|
|
ID: id,
|
|
Score: nil,
|
|
Properties: props,
|
|
})
|
|
}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// GetAll with filtering + pagination
|
|
func (s *WeaviateStore) GetAll(ctx context.Context, className string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
|
|
where := buildWeaviateFilter(queries)
|
|
|
|
fields := []graphql.Field{
|
|
{Name: "_additional", Fields: []graphql.Field{
|
|
{Name: "id"},
|
|
}},
|
|
}
|
|
for _, field := range selectFields {
|
|
fields = append(fields, graphql.Field{Name: field})
|
|
}
|
|
|
|
search := s.client.GraphQL().Get().
|
|
WithClassName(className).
|
|
WithLimit(int(limit)).
|
|
WithFields(fields...)
|
|
|
|
if where != nil {
|
|
search = search.WithWhere(where)
|
|
}
|
|
if cursor != nil {
|
|
search = search.WithAfter(*cursor)
|
|
}
|
|
|
|
resp, err := search.Do(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Check for GraphQL errors
|
|
if len(resp.Errors) > 0 {
|
|
var errorMsgs []string
|
|
for _, err := range resp.Errors {
|
|
errorMsgs = append(errorMsgs, err.Message)
|
|
}
|
|
return nil, nil, fmt.Errorf("graphql errors: %v", errorMsgs)
|
|
}
|
|
|
|
data, ok := resp.Data["Get"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
|
|
}
|
|
|
|
objsRaw, exists := data[className]
|
|
if !exists {
|
|
// No results for this class - this is normal, not an error
|
|
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
|
|
return nil, nil, nil
|
|
}
|
|
|
|
objs, ok := objsRaw.([]interface{})
|
|
if !ok {
|
|
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
|
|
return nil, nil, nil
|
|
}
|
|
|
|
results := make([]SearchResult, 0, len(objs))
|
|
var nextCursor *string
|
|
for _, o := range objs {
|
|
obj, ok := o.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
// Convert to SearchResult format for consistency
|
|
searchResult := SearchResult{
|
|
Properties: obj,
|
|
}
|
|
|
|
if additional, ok := obj["_additional"].(map[string]interface{}); ok {
|
|
if id, ok := additional["id"].(string); ok {
|
|
searchResult.ID = id
|
|
nextCursor = &id
|
|
}
|
|
}
|
|
|
|
results = append(results, searchResult)
|
|
}
|
|
|
|
return results, nextCursor, nil
|
|
}
|
|
|
|
// GetNearest with explicit filters only
|
|
func (s *WeaviateStore) GetNearest(
|
|
ctx context.Context,
|
|
className string,
|
|
vector []float32,
|
|
queries []Query,
|
|
selectFields []string,
|
|
threshold float64,
|
|
limit int64,
|
|
) ([]SearchResult, error) {
|
|
where := buildWeaviateFilter(queries)
|
|
|
|
fields := []graphql.Field{
|
|
{Name: "_additional", Fields: []graphql.Field{
|
|
{Name: "id"},
|
|
{Name: "certainty"},
|
|
}},
|
|
}
|
|
|
|
for _, field := range selectFields {
|
|
fields = append(fields, graphql.Field{Name: field})
|
|
}
|
|
|
|
nearVector := s.client.GraphQL().NearVectorArgBuilder().
|
|
WithVector(vector).
|
|
WithCertainty(float32(threshold))
|
|
|
|
search := s.client.GraphQL().Get().
|
|
WithClassName(className).
|
|
WithNearVector(nearVector).
|
|
WithLimit(int(limit)).
|
|
WithFields(fields...)
|
|
|
|
if where != nil {
|
|
search = search.WithWhere(where)
|
|
}
|
|
|
|
resp, err := search.Do(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check for GraphQL errors
|
|
if len(resp.Errors) > 0 {
|
|
var errorMsgs []string
|
|
for _, err := range resp.Errors {
|
|
errorMsgs = append(errorMsgs, err.Message)
|
|
}
|
|
return nil, fmt.Errorf("graphql errors: %v", errorMsgs)
|
|
}
|
|
|
|
data, ok := resp.Data["Get"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
|
|
}
|
|
|
|
objsRaw, exists := data[className]
|
|
if !exists {
|
|
// No results for this class - this is normal, not an error
|
|
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
|
|
return nil, nil
|
|
}
|
|
|
|
objs, ok := objsRaw.([]interface{})
|
|
if !ok {
|
|
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
|
|
return nil, nil
|
|
}
|
|
|
|
results := make([]SearchResult, 0, len(objs))
|
|
for _, o := range objs {
|
|
obj, ok := o.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
additional, ok := obj["_additional"].(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
// Safely extract ID
|
|
idRaw, exists := additional["id"]
|
|
if !exists || idRaw == nil {
|
|
continue
|
|
}
|
|
id, ok := idRaw.(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
// Safely extract certainty/score with default value
|
|
var score float64
|
|
if certaintyRaw, exists := additional["certainty"]; exists && certaintyRaw != nil {
|
|
switch v := certaintyRaw.(type) {
|
|
case float64:
|
|
score = v
|
|
case float32:
|
|
score = float64(v)
|
|
case int:
|
|
score = float64(v)
|
|
case int64:
|
|
score = float64(v)
|
|
default:
|
|
score = 0.0 // Default score if type conversion fails
|
|
}
|
|
}
|
|
|
|
results = append(results, SearchResult{
|
|
ID: id,
|
|
Score: &score,
|
|
Properties: obj,
|
|
})
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// Delete removes multiple objects by ID
|
|
func (s *WeaviateStore) Delete(ctx context.Context, className string, id string) error {
|
|
return s.client.Data().Deleter().
|
|
WithClassName(className).
|
|
WithID(id).
|
|
Do(ctx)
|
|
}
|
|
|
|
func (s *WeaviateStore) DeleteAll(ctx context.Context, className string, queries []Query) ([]DeleteResult, error) {
|
|
// Check if class exists first to avoid 500 errors from Weaviate
|
|
exists, err := s.client.Schema().ClassExistenceChecker().
|
|
WithClassName(className).
|
|
Do(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to check class existence: %w", err)
|
|
}
|
|
if !exists {
|
|
return []DeleteResult{}, nil // Class doesn't exist, nothing to delete
|
|
}
|
|
|
|
where := buildWeaviateFilter(queries)
|
|
|
|
res, err := s.client.Batch().ObjectsBatchDeleter().
|
|
WithClassName(className).
|
|
WithWhere(where).
|
|
Do(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// NOTE: Weaviate is returning an empty array for Results.Objects, even on successful deletes.
|
|
results := make([]DeleteResult, 0, len(res.Results.Objects))
|
|
|
|
for _, obj := range res.Results.Objects {
|
|
result := DeleteResult{
|
|
ID: obj.ID.String(),
|
|
}
|
|
|
|
if obj.Status != nil {
|
|
switch *obj.Status {
|
|
case "SUCCESS":
|
|
result.Status = DeleteStatusSuccess
|
|
case "FAILED":
|
|
result.Status = DeleteStatusError
|
|
|
|
if obj.Errors != nil {
|
|
var errorMsgs []string
|
|
for _, err := range obj.Errors.Error {
|
|
errorMsgs = append(errorMsgs, err.Message)
|
|
}
|
|
|
|
result.Error = strings.Join(errorMsgs, ", ")
|
|
}
|
|
}
|
|
}
|
|
|
|
results = append(results, result)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (s *WeaviateStore) Close(ctx context.Context, className string) error {
|
|
// nothing to close
|
|
return nil
|
|
}
|
|
|
|
// RequiresVectors returns true because Weaviate's HNSW index
|
|
// requires vectors for proper object indexing and retrieval.
|
|
func (s *WeaviateStore) RequiresVectors() bool {
|
|
return true
|
|
}
|
|
|
|
// newWeaviateStore creates a new Weaviate vector store.
|
|
func newWeaviateStore(ctx context.Context, config *WeaviateConfig, logger schemas.Logger) (*WeaviateStore, error) {
|
|
// Validate required config
|
|
if config.Scheme == "" || (config.Host == nil || config.Host.GetValue() == "") {
|
|
return nil, fmt.Errorf("weaviate scheme and host are required")
|
|
}
|
|
// Build client configuration
|
|
cfg := weaviate.Config{
|
|
Scheme: config.Scheme,
|
|
Host: config.Host.GetValue(),
|
|
}
|
|
|
|
// Add authentication if provided
|
|
if config.APIKey != nil && config.APIKey.GetValue() != "" {
|
|
cfg.AuthConfig = auth.ApiKey{Value: config.APIKey.GetValue()}
|
|
}
|
|
|
|
// Add grpc config if provided
|
|
if config.GrpcConfig != nil {
|
|
if config.GrpcConfig.Host == nil || config.GrpcConfig.Host.GetValue() == "" {
|
|
return nil, fmt.Errorf("weaviate grpc host is required")
|
|
}
|
|
cfg.GrpcConfig = &grpc.Config{
|
|
Host: config.GrpcConfig.Host.GetValue(),
|
|
Secured: config.GrpcConfig.Secured,
|
|
}
|
|
}
|
|
|
|
// Add custom headers if provided
|
|
if len(config.Headers) > 0 {
|
|
cfg.Headers = config.Headers
|
|
}
|
|
|
|
// Create client
|
|
client, err := weaviate.NewClient(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create weaviate client: %w", err)
|
|
}
|
|
|
|
// Test connection with meta endpoint
|
|
testCtx := ctx
|
|
if config.Timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
testCtx, cancel = context.WithTimeout(ctx, config.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
_, err = client.Misc().MetaGetter().Do(testCtx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to weaviate: %w", err)
|
|
}
|
|
|
|
store := &WeaviateStore{
|
|
client: client,
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
|
|
return store, nil
|
|
}
|
|
|
|
func (s *WeaviateStore) CreateNamespace(ctx context.Context, className string, dimension int, properties map[string]VectorStoreProperties) error {
|
|
// Check if class exists
|
|
exists, err := s.client.Schema().ClassExistenceChecker().
|
|
WithClassName(className).
|
|
Do(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check class existence: %w", err)
|
|
}
|
|
|
|
if exists {
|
|
return nil // Schema already exists
|
|
}
|
|
|
|
// Create properties
|
|
weaviateProperties := []*models.Property{}
|
|
for name, prop := range properties {
|
|
var dataType []string
|
|
switch prop.DataType {
|
|
case VectorStorePropertyTypeString:
|
|
dataType = []string{"string"}
|
|
case VectorStorePropertyTypeInteger:
|
|
dataType = []string{"int"}
|
|
case VectorStorePropertyTypeBoolean:
|
|
dataType = []string{"boolean"}
|
|
case VectorStorePropertyTypeStringArray:
|
|
dataType = []string{"string[]"}
|
|
}
|
|
|
|
weaviateProperties = append(weaviateProperties, &models.Property{
|
|
Name: name,
|
|
DataType: dataType,
|
|
Description: prop.Description,
|
|
})
|
|
}
|
|
|
|
// Create class schema with all fields we need
|
|
classSchema := &models.Class{
|
|
Class: className,
|
|
Properties: weaviateProperties,
|
|
VectorIndexType: "hnsw",
|
|
Vectorizer: "none", // We provide our own vectors
|
|
}
|
|
|
|
if dimension > 0 {
|
|
classSchema.VectorIndexConfig = map[string]interface{}{
|
|
"vectorDimensions": dimension,
|
|
}
|
|
}
|
|
|
|
err = s.client.Schema().ClassCreator().
|
|
WithClass(classSchema).
|
|
Do(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create class schema: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *WeaviateStore) DeleteNamespace(ctx context.Context, className string) error {
|
|
exists, err := s.client.Schema().ClassExistenceChecker().
|
|
WithClassName(className).
|
|
Do(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check class existence: %w", err)
|
|
}
|
|
if !exists {
|
|
return nil // Schema already does not exist
|
|
} else {
|
|
return s.client.Schema().ClassDeleter().
|
|
WithClassName(className).
|
|
Do(ctx)
|
|
}
|
|
}
|
|
|
|
// buildWeaviateFilter converts []Query → Weaviate WhereFilter
|
|
func buildWeaviateFilter(queries []Query) *filters.WhereBuilder {
|
|
if len(queries) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var operands []*filters.WhereBuilder
|
|
for _, q := range queries {
|
|
// Convert string operator to filters operator
|
|
operator := convertOperator(q.Operator)
|
|
|
|
fieldPath := strings.Split(q.Field, ".")
|
|
|
|
whereClause := filters.Where().
|
|
WithPath(fieldPath).
|
|
WithOperator(operator)
|
|
|
|
// Special handling for IsNull and IsNotNull
|
|
switch q.Operator {
|
|
case QueryOperatorIsNull:
|
|
whereClause = whereClause.WithValueBoolean(true)
|
|
case QueryOperatorIsNotNull:
|
|
whereClause = whereClause.WithValueBoolean(false)
|
|
default:
|
|
// Set value based on type
|
|
switch v := q.Value.(type) {
|
|
case string:
|
|
whereClause = whereClause.WithValueString(v)
|
|
case int:
|
|
whereClause = whereClause.WithValueInt(int64(v))
|
|
case int64:
|
|
whereClause = whereClause.WithValueInt(v)
|
|
case float32:
|
|
whereClause = whereClause.WithValueNumber(float64(v))
|
|
case float64:
|
|
whereClause = whereClause.WithValueNumber(v)
|
|
case bool:
|
|
whereClause = whereClause.WithValueBoolean(v)
|
|
default:
|
|
// Fallback to string conversion
|
|
whereClause = whereClause.WithValueString(fmt.Sprintf("%v", v))
|
|
}
|
|
}
|
|
|
|
operands = append(operands, whereClause)
|
|
}
|
|
|
|
if len(operands) == 1 {
|
|
return operands[0]
|
|
}
|
|
|
|
// Create AND filter for multiple operands
|
|
return filters.Where().
|
|
WithOperator(filters.And).
|
|
WithOperands(operands)
|
|
}
|
|
|
|
// convertOperator converts string operator to filters operator
|
|
func convertOperator(op QueryOperator) filters.WhereOperator {
|
|
switch op {
|
|
case QueryOperatorEqual:
|
|
return filters.Equal
|
|
case QueryOperatorNotEqual:
|
|
return filters.NotEqual
|
|
case QueryOperatorLessThan:
|
|
return filters.LessThan
|
|
case QueryOperatorLessThanOrEqual:
|
|
return filters.LessThanEqual
|
|
case QueryOperatorGreaterThan:
|
|
return filters.GreaterThan
|
|
case QueryOperatorGreaterThanOrEqual:
|
|
return filters.GreaterThanEqual
|
|
case QueryOperatorLike:
|
|
return filters.Like
|
|
case QueryOperatorContainsAny:
|
|
return filters.ContainsAny
|
|
case QueryOperatorContainsAll:
|
|
return filters.ContainsAll
|
|
case QueryOperatorIsNull:
|
|
return filters.IsNull
|
|
case QueryOperatorIsNotNull: // IsNotNull is not supported by Weaviate, so we use IsNull and negate it.
|
|
return filters.IsNull
|
|
default:
|
|
// Default to Equal if unknown
|
|
return filters.Equal
|
|
}
|
|
}
|