first commit
This commit is contained in:
9
framework/vectorstore/errors.go
Normal file
9
framework/vectorstore/errors.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package vectorstore
|
||||
|
||||
import "errors"
|
||||
|
||||
// Sentinel errors exposed by the vectorstore package. Each carries a fixed
// "vectorstore: ..." message; match them with errors.Is.
var (
	// ErrNotFound signals that the requested item was not found.
	ErrNotFound = errors.New("vectorstore: not found")

	// ErrNotSupported signals that the store does not support the operation.
	ErrNotSupported = errors.New("vectorstore: operation not supported on this store")

	// ErrQuerySyntax signals a syntax error in a query.
	ErrQuerySyntax = errors.New("vectorstore: query syntax error")
)
|
||||
649
framework/vectorstore/pinecone.go
Normal file
649
framework/vectorstore/pinecone.go
Normal file
@@ -0,0 +1,649 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/pinecone-io/go-pinecone/v5/pinecone"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
// PineconeConfig represents the configuration for the Pinecone vector store.
|
||||
type PineconeConfig struct {
|
||||
APIKey schemas.EnvVar `json:"api_key"` // Pinecone API key - REQUIRED
|
||||
IndexHost schemas.EnvVar `json:"index_host"` // Index host URL from Pinecone console - REQUIRED
|
||||
}
|
||||
|
||||
// PineconeStore represents the Pinecone vector store.
|
||||
type PineconeStore struct {
|
||||
client *pinecone.Client
|
||||
indexConn *pinecone.IndexConnection
|
||||
config *PineconeConfig
|
||||
logger schemas.Logger
|
||||
mu sync.RWMutex // Protects namespaces and dimension
|
||||
namespaces map[string]*pinecone.IndexConnection
|
||||
dimension int // Store dimension for zero vector queries in GetAll
|
||||
}
|
||||
|
||||
// Ping checks if the Pinecone server is reachable.
|
||||
func (s *PineconeStore) Ping(ctx context.Context) error {
|
||||
_, err := s.indexConn.DescribeIndexStats(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateNamespace creates a new namespace in the Pinecone vector store.
|
||||
// Note: Pinecone namespaces are created implicitly when upserting vectors.
|
||||
// This method is a no-op but ensures the connection is valid.
|
||||
func (s *PineconeStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error {
|
||||
// Store dimension for use in GetAll (zero vector queries)
|
||||
s.mu.Lock()
|
||||
s.dimension = dimension
|
||||
s.mu.Unlock()
|
||||
|
||||
// Pinecone namespaces are created automatically on first upsert.
|
||||
// We just verify the index connection is valid.
|
||||
_, err := s.indexConn.DescribeIndexStats(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify index connection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNamespace deletes a namespace from the Pinecone vector store.
|
||||
func (s *PineconeStore) DeleteNamespace(ctx context.Context, namespace string) error {
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return idxConn.DeleteAllVectorsInNamespace(ctx)
|
||||
}
|
||||
|
||||
// GetChunk retrieves a single vector from the Pinecone vector store.
|
||||
func (s *PineconeStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return SearchResult{}, fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return SearchResult{}, err
|
||||
}
|
||||
|
||||
res, err := idxConn.FetchVectors(ctx, []string{id})
|
||||
if err != nil {
|
||||
return SearchResult{}, fmt.Errorf("failed to fetch vector: %w", err)
|
||||
}
|
||||
|
||||
if len(res.Vectors) == 0 {
|
||||
return SearchResult{}, fmt.Errorf("not found: %s", id)
|
||||
}
|
||||
|
||||
vec, exists := res.Vectors[id]
|
||||
if !exists || vec == nil {
|
||||
return SearchResult{}, fmt.Errorf("not found: %s", id)
|
||||
}
|
||||
|
||||
return SearchResult{
|
||||
ID: id,
|
||||
Properties: metadataToMap(vec.Metadata),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChunks retrieves multiple vectors from the Pinecone vector store.
|
||||
func (s *PineconeStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) {
|
||||
if len(ids) == 0 {
|
||||
return []SearchResult{}, nil
|
||||
}
|
||||
|
||||
// Filter out empty IDs
|
||||
validIDs := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if strings.TrimSpace(id) != "" {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validIDs) == 0 {
|
||||
return []SearchResult{}, nil
|
||||
}
|
||||
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := idxConn.FetchVectors(ctx, validIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch vectors: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(res.Vectors))
|
||||
for id, vec := range res.Vectors {
|
||||
if vec != nil {
|
||||
results = append(results, SearchResult{
|
||||
ID: id,
|
||||
Properties: metadataToMap(vec.Metadata),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetAll retrieves all vectors with optional filtering and pagination.
|
||||
// Note: This implementation uses QueryByVectorValues with a zero vector instead of ListVectors
|
||||
// because ListVectors has severe eventual consistency issues on Pinecone Serverless/Starter indexes.
|
||||
// The metadata filtering is done server-side by Pinecone, providing much better consistency.
|
||||
func (s *PineconeStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
topK := uint32(limit)
|
||||
if limit <= 0 {
|
||||
topK = 100
|
||||
}
|
||||
|
||||
// Create zero vector for query - this allows us to use QueryByVectorValues
|
||||
// which has much better consistency than ListVectors
|
||||
s.mu.RLock()
|
||||
dim := s.dimension
|
||||
s.mu.RUnlock()
|
||||
if dim <= 0 {
|
||||
return nil, nil, fmt.Errorf("dimension not set: CreateNamespace must be called before GetAll to set the vector dimension")
|
||||
}
|
||||
zeroVector := make([]float32, dim)
|
||||
|
||||
queryReq := &pinecone.QueryByVectorValuesRequest{
|
||||
Vector: zeroVector,
|
||||
TopK: topK,
|
||||
IncludeValues: false,
|
||||
IncludeMetadata: true,
|
||||
}
|
||||
|
||||
// Build metadata filter from queries - filtering is done server-side
|
||||
if len(queries) > 0 {
|
||||
filter, filterErr := buildPineconeFilter(queries)
|
||||
if filterErr != nil {
|
||||
s.logger.Warn("failed to build pinecone filter, queries may not be applied: %v", filterErr)
|
||||
}
|
||||
if filter != nil {
|
||||
queryReq.MetadataFilter = filter
|
||||
}
|
||||
}
|
||||
|
||||
res, err := idxConn.QueryByVectorValues(ctx, queryReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to query vectors: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(res.Matches))
|
||||
for _, match := range res.Matches {
|
||||
if match.Vector == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
props := metadataToMap(match.Vector.Metadata)
|
||||
filteredProps := filterPropertiesPinecone(props, selectFields)
|
||||
|
||||
results = append(results, SearchResult{
|
||||
ID: match.Vector.Id,
|
||||
Properties: filteredProps,
|
||||
})
|
||||
}
|
||||
|
||||
// Note: QueryByVectorValues doesn't support pagination tokens like ListVectors
|
||||
// For direct hash lookup (the main use case), we only need 1 result anyway
|
||||
return results, nil, nil
|
||||
}
|
||||
|
||||
// GetNearest retrieves the nearest vectors to a given vector.
|
||||
func (s *PineconeStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) {
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
topK := uint32(limit)
|
||||
if limit <= 0 {
|
||||
topK = 10
|
||||
}
|
||||
|
||||
queryReq := &pinecone.QueryByVectorValuesRequest{
|
||||
Vector: vector,
|
||||
TopK: topK,
|
||||
IncludeValues: false,
|
||||
IncludeMetadata: true,
|
||||
}
|
||||
|
||||
// Build metadata filter from queries
|
||||
if len(queries) > 0 {
|
||||
filter, err := buildPineconeFilter(queries)
|
||||
if err != nil {
|
||||
s.logger.Debug(fmt.Sprintf("failed to build pinecone filter: %v", err))
|
||||
} else if filter != nil {
|
||||
queryReq.MetadataFilter = filter
|
||||
}
|
||||
}
|
||||
|
||||
res, err := idxConn.QueryByVectorValues(ctx, queryReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query vectors: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(res.Matches))
|
||||
for _, match := range res.Matches {
|
||||
if match.Vector == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
score := float64(match.Score)
|
||||
|
||||
// Apply threshold filter
|
||||
if score < threshold {
|
||||
continue
|
||||
}
|
||||
|
||||
props := metadataToMap(match.Vector.Metadata)
|
||||
filteredProps := filterPropertiesPinecone(props, selectFields)
|
||||
|
||||
results = append(results, SearchResult{
|
||||
ID: match.Vector.Id,
|
||||
Score: &score,
|
||||
Properties: filteredProps,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// convertMetadataForStructpb returns a copy of metadata that is safe to pass
// to structpb.NewStruct. Top-level []string values are rewritten as
// []interface{}, since structpb does not handle []string directly; every other
// value is copied as-is. The input map is not modified, and a nil input yields
// nil.
func convertMetadataForStructpb(metadata map[string]interface{}) map[string]interface{} {
	if metadata == nil {
		return nil
	}

	out := make(map[string]interface{}, len(metadata))
	for key, value := range metadata {
		strs, isStringSlice := value.([]string)
		if !isStringSlice {
			out[key] = value
			continue
		}

		generic := make([]interface{}, 0, len(strs))
		for _, item := range strs {
			generic = append(generic, item)
		}
		out[key] = generic
	}
	return out
}
|
||||
|
||||
// Add stores a new vector in the Pinecone vector store.
|
||||
func (s *PineconeStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert metadata to structpb (handle []string -> []interface{} conversion)
|
||||
var pbMetadata *structpb.Struct
|
||||
if len(metadata) > 0 {
|
||||
convertedMetadata := convertMetadataForStructpb(metadata)
|
||||
pbMetadata, err = structpb.NewStruct(convertedMetadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert metadata: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
vec := &pinecone.Vector{
|
||||
Id: id,
|
||||
Metadata: pbMetadata,
|
||||
}
|
||||
|
||||
if len(embedding) > 0 {
|
||||
vec.Values = &embedding
|
||||
}
|
||||
|
||||
_, err = idxConn.UpsertVectors(ctx, []*pinecone.Vector{vec})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert vector: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a vector from the Pinecone vector store.
|
||||
func (s *PineconeStore) Delete(ctx context.Context, namespace string, id string) error {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return idxConn.DeleteVectorsById(ctx, []string{id})
|
||||
}
|
||||
|
||||
// DeleteAll removes multiple vectors matching the filter.
|
||||
func (s *PineconeStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) {
|
||||
idxConn, err := s.getNamespaceConnection(namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we have queries, use filter-based deletion
|
||||
if len(queries) > 0 {
|
||||
filter, err := buildPineconeFilter(queries)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build filter: %w", err)
|
||||
}
|
||||
|
||||
if filter != nil {
|
||||
err = idxConn.DeleteVectorsByFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete vectors by filter: %w", err)
|
||||
}
|
||||
// Pinecone doesn't return individual results for filter-based deletion
|
||||
return []DeleteResult{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If no queries, list and delete all vectors in the namespace
|
||||
listRes, err := idxConn.ListVectors(ctx, &pinecone.ListVectorsRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list vectors: %w", err)
|
||||
}
|
||||
|
||||
if len(listRes.VectorIds) == 0 {
|
||||
return []DeleteResult{}, nil
|
||||
}
|
||||
|
||||
// Convert []*string to []string
|
||||
deleteIDs := make([]string, 0, len(listRes.VectorIds))
|
||||
for _, id := range listRes.VectorIds {
|
||||
if id != nil {
|
||||
deleteIDs = append(deleteIDs, *id)
|
||||
}
|
||||
}
|
||||
|
||||
results := make([]DeleteResult, len(deleteIDs))
|
||||
for i, id := range deleteIDs {
|
||||
results[i] = DeleteResult{
|
||||
ID: id,
|
||||
Status: DeleteStatusSuccess,
|
||||
}
|
||||
}
|
||||
|
||||
err = idxConn.DeleteVectorsById(ctx, deleteIDs)
|
||||
if err != nil {
|
||||
for i := range results {
|
||||
results[i].Status = DeleteStatusError
|
||||
results[i].Error = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Close closes the Pinecone client connection.
|
||||
// If namespace is non-empty, only that namespace connection is closed.
|
||||
// If namespace is empty, all connections (indexConn and all namespaces) are closed.
|
||||
func (s *PineconeStore) Close(ctx context.Context, namespace string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// If a specific namespace is provided, close only that connection
|
||||
if namespace != "" {
|
||||
if conn, exists := s.namespaces[namespace]; exists && conn != nil {
|
||||
conn.Close()
|
||||
delete(s.namespaces, namespace)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Close all connections when namespace is empty
|
||||
var errs []error
|
||||
// Close the main index connection
|
||||
if s.indexConn != nil {
|
||||
s.indexConn.Close()
|
||||
s.indexConn = nil
|
||||
}
|
||||
// Close and remove all namespace connections
|
||||
for ns, conn := range s.namespaces {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
delete(s.namespaces, ns)
|
||||
}
|
||||
// Return aggregated errors if any occurred
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("errors closing connections: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequiresVectors returns true because Pinecone is a dedicated vector database
|
||||
// that requires vectors for all entries with a specific dimension.
|
||||
func (s *PineconeStore) RequiresVectors() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// newPineconeStore creates a new Pinecone vector store.
|
||||
func newPineconeStore(ctx context.Context, config *PineconeConfig, logger schemas.Logger) (*PineconeStore, error) {
|
||||
if strings.TrimSpace(config.APIKey.GetValue()) == "" {
|
||||
return nil, fmt.Errorf("pinecone api_key is required")
|
||||
}
|
||||
if strings.TrimSpace(config.IndexHost.GetValue()) == "" {
|
||||
return nil, fmt.Errorf("pinecone index_host is required")
|
||||
}
|
||||
// Creating new client
|
||||
client, err := pinecone.NewClient(pinecone.NewClientParams{
|
||||
ApiKey: config.APIKey.GetValue(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create pinecone client: %w", err)
|
||||
}
|
||||
// Prepare the host URL
|
||||
// For local connections (Pinecone Local), prefix with http:// to disable TLS
|
||||
// See: https://docs.pinecone.io/guides/operations/local-development
|
||||
host := config.IndexHost.GetValue()
|
||||
if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") {
|
||||
// Check if this looks like a local connection
|
||||
if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") {
|
||||
host = "http://" + host
|
||||
}
|
||||
}
|
||||
// Create index connection
|
||||
idxConn, err := client.Index(pinecone.NewIndexConnParams{
|
||||
Host: host,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create index connection: %w", err)
|
||||
}
|
||||
// Verify connection by getting index stats
|
||||
_, err = idxConn.DescribeIndexStats(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to pinecone index: %w", err)
|
||||
}
|
||||
return &PineconeStore{
|
||||
client: client,
|
||||
indexConn: idxConn,
|
||||
config: config,
|
||||
logger: logger,
|
||||
namespaces: make(map[string]*pinecone.IndexConnection),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getHostWithScheme returns the host with the appropriate scheme.
|
||||
// For local connections (localhost/127.0.0.1), it adds http:// to disable TLS.
|
||||
func (s *PineconeStore) getHostWithScheme() string {
|
||||
host := s.config.IndexHost.GetValue()
|
||||
if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") {
|
||||
if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") {
|
||||
return "http://" + host
|
||||
}
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// getNamespaceConnection returns or creates a connection for the given namespace.
|
||||
func (s *PineconeStore) getNamespaceConnection(namespace string) (*pinecone.IndexConnection, error) {
|
||||
if namespace == "" {
|
||||
return s.indexConn, nil
|
||||
}
|
||||
// Check if we already have a connection for this namespace (optimistic read)
|
||||
s.mu.RLock()
|
||||
if conn, exists := s.namespaces[namespace]; exists {
|
||||
s.mu.RUnlock()
|
||||
return conn, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
// Acquire write lock to create new connection
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Double-check after acquiring write lock (another goroutine may have created it)
|
||||
if conn, exists := s.namespaces[namespace]; exists {
|
||||
return conn, nil
|
||||
}
|
||||
// Create a new connection for this namespace
|
||||
conn, err := s.client.Index(pinecone.NewIndexConnParams{
|
||||
Host: s.getHostWithScheme(),
|
||||
Namespace: namespace,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create namespace connection: %w", err)
|
||||
}
|
||||
s.namespaces[namespace] = conn
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// metadataToMap converts protobuf Struct to map[string]interface{}.
|
||||
func metadataToMap(metadata *structpb.Struct) map[string]interface{} {
|
||||
if metadata == nil {
|
||||
return make(map[string]interface{})
|
||||
}
|
||||
return metadata.AsMap()
|
||||
}
|
||||
|
||||
// filterPropertiesPinecone narrows props to the keys named in selectFields.
// With no selectFields, props itself is returned (not a copy). Otherwise a new
// map is built holding only the selected keys that exist in props.
func filterPropertiesPinecone(props map[string]interface{}, selectFields []string) map[string]interface{} {
	wanted := len(selectFields)
	if wanted == 0 {
		return props
	}

	selected := make(map[string]interface{}, wanted)
	for i := 0; i < wanted; i++ {
		name := selectFields[i]
		value, present := props[name]
		if !present {
			continue
		}
		selected[name] = value
	}
	return selected
}
|
||||
|
||||
// matchesQueries checks if properties match all query conditions.
|
||||
func matchesQueries(props map[string]interface{}, queries []Query) bool {
|
||||
if len(queries) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, q := range queries {
|
||||
val, exists := props[q.Field]
|
||||
if !matchesQuery(val, exists, q) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// matchesQuery checks if a single value matches a query condition.
|
||||
func matchesQuery(val interface{}, exists bool, q Query) bool {
|
||||
switch q.Operator {
|
||||
case QueryOperatorIsNull:
|
||||
return !exists || val == nil
|
||||
case QueryOperatorIsNotNull:
|
||||
return exists && val != nil
|
||||
case QueryOperatorEqual:
|
||||
return exists && fmt.Sprintf("%v", val) == fmt.Sprintf("%v", q.Value)
|
||||
case QueryOperatorNotEqual:
|
||||
return !exists || fmt.Sprintf("%v", val) != fmt.Sprintf("%v", q.Value)
|
||||
default:
|
||||
// For complex operators, default to true (filter at query time)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// buildPineconeFilter converts queries to Pinecone metadata filter.
|
||||
func buildPineconeFilter(queries []Query) (*structpb.Struct, error) {
|
||||
if len(queries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
filterMap := make(map[string]interface{})
|
||||
|
||||
for _, q := range queries {
|
||||
condition := buildPineconeCondition(q)
|
||||
if condition != nil {
|
||||
filterMap[q.Field] = condition
|
||||
}
|
||||
}
|
||||
|
||||
if len(filterMap) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return structpb.NewStruct(filterMap)
|
||||
}
|
||||
|
||||
// buildPineconeCondition builds a single Pinecone filter condition.
|
||||
func buildPineconeCondition(q Query) interface{} {
|
||||
switch q.Operator {
|
||||
case QueryOperatorEqual:
|
||||
return map[string]interface{}{"$eq": q.Value}
|
||||
case QueryOperatorNotEqual:
|
||||
return map[string]interface{}{"$ne": q.Value}
|
||||
case QueryOperatorGreaterThan:
|
||||
return map[string]interface{}{"$gt": q.Value}
|
||||
case QueryOperatorGreaterThanOrEqual:
|
||||
return map[string]interface{}{"$gte": q.Value}
|
||||
case QueryOperatorLessThan:
|
||||
return map[string]interface{}{"$lt": q.Value}
|
||||
case QueryOperatorLessThanOrEqual:
|
||||
return map[string]interface{}{"$lte": q.Value}
|
||||
case QueryOperatorIsNull:
|
||||
return map[string]interface{}{"$eq": nil}
|
||||
case QueryOperatorIsNotNull:
|
||||
return map[string]interface{}{"$ne": nil}
|
||||
case QueryOperatorContainsAny:
|
||||
return map[string]interface{}{"$in": q.Value}
|
||||
case QueryOperatorContainsAll:
|
||||
// Build an $and array of equality checks so all values must match
|
||||
values, ok := q.Value.([]interface{})
|
||||
if !ok {
|
||||
// Try to convert []string to []interface{}
|
||||
if strValues, ok := q.Value.([]string); ok {
|
||||
values = make([]interface{}, len(strValues))
|
||||
for i, v := range strValues {
|
||||
values[i] = v
|
||||
}
|
||||
} else {
|
||||
// Fallback to single value equality
|
||||
return map[string]interface{}{"$eq": q.Value}
|
||||
}
|
||||
}
|
||||
andConditions := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
andConditions[i] = map[string]interface{}{"$eq": v}
|
||||
}
|
||||
return map[string]interface{}{"$and": andConditions}
|
||||
default:
|
||||
return map[string]interface{}{"$eq": q.Value}
|
||||
}
|
||||
}
|
||||
611
framework/vectorstore/pinecone_test.go
Normal file
611
framework/vectorstore/pinecone_test.go
Normal file
@@ -0,0 +1,611 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Shared settings for the Pinecone tests. The default API key and index host
// target a Pinecone Local instance; NewPineconeTestSetup lets the
// PINECONE_API_KEY and PINECONE_INDEX_HOST environment variables override them.
const (
	// PineconeTestTimeout bounds the context created for each test setup.
	PineconeTestTimeout = 30 * time.Second

	// PineconeTestNamespace is the namespace the tests write to and clean up.
	PineconeTestNamespace = "bifrost-test-namespace"

	// PineconeTestDimension matches the text-embedding-3-small dimension.
	PineconeTestDimension = 1536

	// PineconeTestDefaultAPIKey is a placeholder key; Pinecone Local doesn't
	// validate API keys.
	PineconeTestDefaultAPIKey = "pclocal"

	// PineconeTestDefaultIndexHost is the Pinecone Local default host and port.
	PineconeTestDefaultIndexHost = "localhost:5081"
)
|
||||
|
||||
type PineconeTestSetup struct {
|
||||
Store *PineconeStore
|
||||
Logger schemas.Logger
|
||||
Config PineconeConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewPineconeTestSetup(t *testing.T) *PineconeTestSetup {
|
||||
apiKey := schemas.NewEnvVar(getEnvWithDefault("PINECONE_API_KEY", PineconeTestDefaultAPIKey))
|
||||
indexHost := schemas.NewEnvVar(getEnvWithDefault("PINECONE_INDEX_HOST", PineconeTestDefaultIndexHost))
|
||||
|
||||
config := PineconeConfig{
|
||||
APIKey: *apiKey,
|
||||
IndexHost: *indexHost,
|
||||
}
|
||||
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), PineconeTestTimeout)
|
||||
|
||||
store, err := newPineconeStore(ctx, &config, logger)
|
||||
if err != nil {
|
||||
cancel()
|
||||
t.Fatalf("Failed to create Pinecone store: %v", err)
|
||||
}
|
||||
|
||||
setup := &PineconeTestSetup{
|
||||
Store: store,
|
||||
Logger: logger,
|
||||
Config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
return setup
|
||||
}
|
||||
|
||||
func (ts *PineconeTestSetup) Cleanup(t *testing.T) {
|
||||
defer ts.cancel()
|
||||
|
||||
if !testing.Short() {
|
||||
ts.cleanupTestData(t)
|
||||
}
|
||||
|
||||
if err := ts.Store.Close(ts.ctx, PineconeTestNamespace); err != nil {
|
||||
t.Logf("Warning: Failed to close store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *PineconeTestSetup) cleanupTestData(t *testing.T) {
|
||||
// Delete all vectors in the test namespace
|
||||
err := ts.Store.DeleteNamespace(ts.ctx, PineconeTestNamespace)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to cleanup test namespace: %v", err)
|
||||
}
|
||||
t.Logf("Cleaned up test namespace: %s", PineconeTestNamespace)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// UNIT TESTS
|
||||
// ============================================================================
|
||||
|
||||
func TestPineconeConfig_Validation(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config PineconeConfig
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "missing api key",
|
||||
config: PineconeConfig{
|
||||
IndexHost: *schemas.NewEnvVar("https://my-index.svc.environment.pinecone.io"),
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "pinecone api_key is required",
|
||||
},
|
||||
{
|
||||
name: "missing index host",
|
||||
config: PineconeConfig{
|
||||
APIKey: *schemas.NewEnvVar("test-api-key"),
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "pinecone index_host is required",
|
||||
},
|
||||
{
|
||||
name: "empty api key",
|
||||
config: PineconeConfig{
|
||||
APIKey: *schemas.NewEnvVar(""),
|
||||
IndexHost: *schemas.NewEnvVar("https://my-index.svc.environment.pinecone.io"),
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "pinecone api_key is required",
|
||||
},
|
||||
{
|
||||
name: "empty index host",
|
||||
config: PineconeConfig{
|
||||
APIKey: *schemas.NewEnvVar("test-api-key"),
|
||||
IndexHost: *schemas.NewEnvVar(""),
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "pinecone index_host is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store, err := newPineconeStore(ctx, &tt.config, logger)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, store)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
// Note: This will fail with connection error in unit tests
|
||||
// but should pass config validation
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "failed to connect")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPineconeFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queries []Query
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty queries",
|
||||
queries: []Query{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "single string query",
|
||||
queries: []Query{
|
||||
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "single numeric query",
|
||||
queries: []Query{
|
||||
{Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple queries",
|
||||
queries: []Query{
|
||||
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
|
||||
{Field: "public", Operator: QueryOperatorEqual, Value: true},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not equal query",
|
||||
queries: []Query{
|
||||
{Field: "status", Operator: QueryOperatorNotEqual, Value: "deleted"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "range queries",
|
||||
queries: []Query{
|
||||
{Field: "count", Operator: QueryOperatorGreaterThanOrEqual, Value: 10},
|
||||
{Field: "score", Operator: QueryOperatorLessThanOrEqual, Value: 100},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := buildPineconeFilter(tt.queries)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tt.expected {
|
||||
assert.NotNil(t, result)
|
||||
} else {
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPineconeCondition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query Query
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "equal operator",
|
||||
query: Query{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
|
||||
expected: map[string]interface{}{"$eq": "tech"},
|
||||
},
|
||||
{
|
||||
name: "not equal operator",
|
||||
query: Query{Field: "status", Operator: QueryOperatorNotEqual, Value: "deleted"},
|
||||
expected: map[string]interface{}{"$ne": "deleted"},
|
||||
},
|
||||
{
|
||||
name: "greater than operator",
|
||||
query: Query{Field: "count", Operator: QueryOperatorGreaterThan, Value: 10},
|
||||
expected: map[string]interface{}{"$gt": 10},
|
||||
},
|
||||
{
|
||||
name: "greater than or equal operator",
|
||||
query: Query{Field: "count", Operator: QueryOperatorGreaterThanOrEqual, Value: 10},
|
||||
expected: map[string]interface{}{"$gte": 10},
|
||||
},
|
||||
{
|
||||
name: "less than operator",
|
||||
query: Query{Field: "score", Operator: QueryOperatorLessThan, Value: 100},
|
||||
expected: map[string]interface{}{"$lt": 100},
|
||||
},
|
||||
{
|
||||
name: "less than or equal operator",
|
||||
query: Query{Field: "score", Operator: QueryOperatorLessThanOrEqual, Value: 100},
|
||||
expected: map[string]interface{}{"$lte": 100},
|
||||
},
|
||||
{
|
||||
name: "contains any operator",
|
||||
query: Query{Field: "tags", Operator: QueryOperatorContainsAny, Value: []string{"a", "b"}},
|
||||
expected: map[string]interface{}{"$in": []string{"a", "b"}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildPineconeCondition(tt.query)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesQueries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
props map[string]interface{}
|
||||
queries []Query
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty queries matches all",
|
||||
props: map[string]interface{}{"type": "document"},
|
||||
queries: []Query{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "equal match",
|
||||
props: map[string]interface{}{"type": "document"},
|
||||
queries: []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "document"}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "equal no match",
|
||||
props: map[string]interface{}{"type": "document"},
|
||||
queries: []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "image"}},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "not equal match",
|
||||
props: map[string]interface{}{"type": "document"},
|
||||
queries: []Query{{Field: "type", Operator: QueryOperatorNotEqual, Value: "image"}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is null match",
|
||||
props: map[string]interface{}{"type": "document"},
|
||||
queries: []Query{{Field: "author", Operator: QueryOperatorIsNull, Value: nil}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is not null match",
|
||||
props: map[string]interface{}{"type": "document", "author": "alice"},
|
||||
queries: []Query{{Field: "author", Operator: QueryOperatorIsNotNull, Value: nil}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple queries all match",
|
||||
props: map[string]interface{}{"type": "document", "public": true},
|
||||
queries: []Query{
|
||||
{Field: "type", Operator: QueryOperatorEqual, Value: "document"},
|
||||
{Field: "public", Operator: QueryOperatorEqual, Value: true},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple queries one fails",
|
||||
props: map[string]interface{}{"type": "document", "public": false},
|
||||
queries: []Query{
|
||||
{Field: "type", Operator: QueryOperatorEqual, Value: "document"},
|
||||
{Field: "public", Operator: QueryOperatorEqual, Value: true},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchesQueries(tt.props, tt.queries)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPropertiesPinecone(t *testing.T) {
|
||||
props := map[string]interface{}{
|
||||
"type": "document",
|
||||
"author": "alice",
|
||||
"size": 1024,
|
||||
"public": true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
selectFields []string
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "empty select returns all",
|
||||
selectFields: []string{},
|
||||
expected: props,
|
||||
},
|
||||
{
|
||||
name: "select single field",
|
||||
selectFields: []string{"type"},
|
||||
expected: map[string]interface{}{"type": "document"},
|
||||
},
|
||||
{
|
||||
name: "select multiple fields",
|
||||
selectFields: []string{"type", "author"},
|
||||
expected: map[string]interface{}{"type": "document", "author": "alice"},
|
||||
},
|
||||
{
|
||||
name: "select non-existent field",
|
||||
selectFields: []string{"missing"},
|
||||
expected: map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterPropertiesPinecone(props, tt.selectFields)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// INTEGRATION TESTS (require real Pinecone instance)
|
||||
// ============================================================================
|
||||
|
||||
func TestPineconeStore_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewPineconeTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
// Test Ping
|
||||
err := setup.Store.Ping(setup.ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test Add and GetChunk
|
||||
key := generateUUID()
|
||||
embedding := generateTestEmbedding(PineconeTestDimension)
|
||||
metadata := map[string]interface{}{
|
||||
"type": "document",
|
||||
"author": "test",
|
||||
}
|
||||
|
||||
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, key, embedding, metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for eventual consistency
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
result, err := setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, key, result.ID)
|
||||
assert.Equal(t, "document", result.Properties["type"])
|
||||
assert.Equal(t, "test", result.Properties["author"])
|
||||
|
||||
// Test GetChunks
|
||||
key2 := generateUUID()
|
||||
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, key2, generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "image"})
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
results, err := setup.Store.GetChunks(setup.ctx, PineconeTestNamespace, []string{key, key2})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
}
|
||||
|
||||
func TestPineconeStore_VectorSearch(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewPineconeTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
// Add test vectors
|
||||
emb := generateTestEmbedding(PineconeTestDimension)
|
||||
err := setup.Store.Add(setup.ctx, PineconeTestNamespace, generateUUID(), emb, map[string]interface{}{"type": "tech"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, generateUUID(), generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "sports"})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for eventual consistency
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Test vector similarity search
|
||||
results, err := setup.Store.GetNearest(setup.ctx, PineconeTestNamespace, emb, nil, []string{"type"}, 0.1, 10)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(results), 1)
|
||||
|
||||
if len(results) > 0 {
|
||||
require.NotNil(t, results[0].Score)
|
||||
}
|
||||
|
||||
// Test with filter
|
||||
queries := []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "tech"}}
|
||||
results, err = setup.Store.GetNearest(setup.ctx, PineconeTestNamespace, emb, queries, []string{"type"}, 0.1, 10)
|
||||
require.NoError(t, err)
|
||||
for _, result := range results {
|
||||
assert.Equal(t, "tech", result.Properties["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPineconeStore_Delete(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewPineconeTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
// Add a vector
|
||||
key := generateUUID()
|
||||
err := setup.Store.Add(setup.ctx, PineconeTestNamespace, key, generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "to-delete"})
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify it exists
|
||||
_, err = setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete it
|
||||
err = setup.Store.Delete(setup.ctx, PineconeTestNamespace, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify it's gone
|
||||
_, err = setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, key)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestPineconeStore_ErrorHandling(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewPineconeTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
// Test GetChunk with non-existent ID
|
||||
_, err := setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, generateUUID())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
|
||||
// Test Add with empty ID
|
||||
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, "", generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "test"})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "id is required")
|
||||
|
||||
// Test Delete with empty ID
|
||||
err = setup.Store.Delete(setup.ctx, PineconeTestNamespace, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "id is required")
|
||||
}
|
||||
|
||||
func TestPineconeStore_SemanticCacheWorkflow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewPineconeTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
// Simulate a semantic cache workflow
|
||||
cacheEntries := []struct {
|
||||
key string
|
||||
embedding []float32
|
||||
metadata map[string]interface{}
|
||||
}{
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(PineconeTestDimension),
|
||||
map[string]interface{}{
|
||||
"request_hash": "abc123",
|
||||
"user": "u1",
|
||||
"lang": "en",
|
||||
"response": "answer1",
|
||||
},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(PineconeTestDimension),
|
||||
map[string]interface{}{
|
||||
"request_hash": "def456",
|
||||
"user": "u1",
|
||||
"lang": "es",
|
||||
"response": "answer2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add cache entries
|
||||
for _, entry := range cacheEntries {
|
||||
err := setup.Store.Add(setup.ctx, PineconeTestNamespace, entry.key, entry.embedding, entry.metadata)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Test semantic search with user filter
|
||||
userFilter := []Query{{Field: "user", Operator: QueryOperatorEqual, Value: "u1"}}
|
||||
results, err := setup.Store.GetNearest(setup.ctx, PineconeTestNamespace, cacheEntries[0].embedding, userFilter, []string{"request_hash", "user", "lang", "response"}, 0.1, 10)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(results), 1)
|
||||
|
||||
// Verify user filter worked
|
||||
for _, result := range results {
|
||||
assert.Equal(t, "u1", result.Properties["user"])
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// INTERFACE COMPLIANCE TESTS
|
||||
// ============================================================================
|
||||
|
||||
func TestPineconeStore_InterfaceCompliance(t *testing.T) {
|
||||
var _ VectorStore = (*PineconeStore)(nil)
|
||||
}
|
||||
|
||||
func TestVectorStoreFactory_Pinecone(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
apiKey := schemas.NewEnvVar(getEnvWithDefault("PINECONE_API_KEY", PineconeTestDefaultAPIKey))
|
||||
indexHost := schemas.NewEnvVar(getEnvWithDefault("PINECONE_INDEX_HOST", PineconeTestDefaultIndexHost))
|
||||
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
config := &Config{
|
||||
Enabled: true,
|
||||
Type: VectorStoreTypePinecone,
|
||||
Config: PineconeConfig{
|
||||
APIKey: *apiKey,
|
||||
IndexHost: *indexHost,
|
||||
},
|
||||
}
|
||||
|
||||
store, err := NewVectorStore(context.Background(), config, logger)
|
||||
if err != nil {
|
||||
t.Skipf("Could not create Pinecone store: %v", err)
|
||||
}
|
||||
defer store.Close(context.Background(), PineconeTestNamespace)
|
||||
|
||||
pineconeStore, ok := store.(*PineconeStore)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, pineconeStore)
|
||||
}
|
||||
609
framework/vectorstore/qdrant.go
Normal file
609
framework/vectorstore/qdrant.go
Normal file
@@ -0,0 +1,609 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
)
|
||||
|
||||
// QdrantConfig represents the configuration for the Qdrant vector store.
|
||||
type QdrantConfig struct {
|
||||
Host schemas.EnvVar `json:"host"` // Qdrant server host - REQUIRED
|
||||
Port schemas.EnvVar `json:"port"` // Qdrant server port (fallback to 6334 for gRPC)
|
||||
APIKey schemas.EnvVar `json:"api_key,omitempty"` // API key for authentication - Optional
|
||||
UseTLS schemas.EnvVar `json:"use_tls,omitempty"` // Use TLS for connection - Optional
|
||||
}
|
||||
|
||||
// QdrantStore represents the Qdrant vector store.
|
||||
type QdrantStore struct {
|
||||
client *qdrant.Client
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// Ping checks if the Qdrant server is reachable.
|
||||
func (s *QdrantStore) Ping(ctx context.Context) error {
|
||||
_, err := s.client.HealthCheck(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateNamespace creates a new collection in the Qdrant vector store.
|
||||
func (s *QdrantStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error {
|
||||
exists, err := s.client.CollectionExists(ctx, namespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check collection existence: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
err = s.client.CreateCollection(ctx, &qdrant.CreateCollection{
|
||||
CollectionName: namespace,
|
||||
VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
|
||||
Size: uint64(dimension),
|
||||
Distance: qdrant.Distance_Cosine,
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create collection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for fieldName, prop := range properties {
|
||||
var fieldType qdrant.FieldType
|
||||
switch prop.DataType {
|
||||
case VectorStorePropertyTypeInteger:
|
||||
fieldType = qdrant.FieldType_FieldTypeInteger
|
||||
case VectorStorePropertyTypeBoolean:
|
||||
fieldType = qdrant.FieldType_FieldTypeBool
|
||||
default:
|
||||
fieldType = qdrant.FieldType_FieldTypeKeyword
|
||||
}
|
||||
|
||||
_, err = s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{
|
||||
CollectionName: namespace,
|
||||
FieldName: fieldName,
|
||||
FieldType: &fieldType,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Debug(fmt.Sprintf("failed to create index for field %s: %v", fieldName, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNamespace deletes a collection from the Qdrant vector store.
|
||||
func (s *QdrantStore) DeleteNamespace(ctx context.Context, namespace string) error {
|
||||
exists, err := s.client.CollectionExists(ctx, namespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check collection existence: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return s.client.DeleteCollection(ctx, namespace)
|
||||
}
|
||||
|
||||
// GetChunk retrieves a single point from the Qdrant vector store.
|
||||
func (s *QdrantStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return SearchResult{}, fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
pointID, err := parsePointID(id)
|
||||
if err != nil {
|
||||
return SearchResult{}, fmt.Errorf("invalid id format: %w", err)
|
||||
}
|
||||
|
||||
points, err := s.client.Get(ctx, &qdrant.GetPoints{
|
||||
CollectionName: namespace,
|
||||
Ids: []*qdrant.PointId{pointID},
|
||||
WithPayload: qdrant.NewWithPayload(true),
|
||||
})
|
||||
if err != nil {
|
||||
return SearchResult{}, fmt.Errorf("failed to get point: %w", err)
|
||||
}
|
||||
|
||||
if len(points) == 0 {
|
||||
return SearchResult{}, fmt.Errorf("not found: %s", id)
|
||||
}
|
||||
|
||||
return SearchResult{
|
||||
ID: id,
|
||||
Properties: payloadToMap(points[0].Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChunks retrieves multiple points from the Qdrant vector store.
|
||||
func (s *QdrantStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) {
|
||||
if len(ids) == 0 {
|
||||
return []SearchResult{}, nil
|
||||
}
|
||||
|
||||
pointIDs := make([]*qdrant.PointId, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
continue
|
||||
}
|
||||
pointID, err := parsePointID(id)
|
||||
if err != nil {
|
||||
s.logger.Debug(fmt.Sprintf("skipping invalid id %s: %v", id, err))
|
||||
continue
|
||||
}
|
||||
pointIDs = append(pointIDs, pointID)
|
||||
}
|
||||
|
||||
if len(pointIDs) == 0 {
|
||||
return []SearchResult{}, nil
|
||||
}
|
||||
|
||||
points, err := s.client.Get(ctx, &qdrant.GetPoints{
|
||||
CollectionName: namespace,
|
||||
Ids: pointIDs,
|
||||
WithPayload: qdrant.NewWithPayload(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get points: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(points))
|
||||
for _, point := range points {
|
||||
results = append(results, SearchResult{
|
||||
ID: pointIDToString(point.Id),
|
||||
Properties: payloadToMap(point.Payload),
|
||||
})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetAll retrieves all points with optional filtering and pagination.
|
||||
func (s *QdrantStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
|
||||
filter := buildQdrantFilter(queries)
|
||||
|
||||
var offset *qdrant.PointId
|
||||
if cursor != nil && *cursor != "" {
|
||||
var err error
|
||||
offset, err = parsePointID(*cursor)
|
||||
if err != nil {
|
||||
s.logger.Debug(fmt.Sprintf("invalid cursor format: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
scrollLimit := uint32(limit)
|
||||
if limit <= 0 {
|
||||
scrollLimit = 100
|
||||
}
|
||||
|
||||
scrollResult, err := s.client.Scroll(ctx, &qdrant.ScrollPoints{
|
||||
CollectionName: namespace,
|
||||
Filter: filter,
|
||||
Limit: &scrollLimit,
|
||||
Offset: offset,
|
||||
WithPayload: qdrant.NewWithPayload(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to scroll points: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(scrollResult))
|
||||
var lastID string
|
||||
|
||||
for _, point := range scrollResult {
|
||||
lastID = pointIDToString(point.Id)
|
||||
results = append(results, SearchResult{
|
||||
ID: lastID,
|
||||
Properties: filterProperties(payloadToMap(point.Payload), selectFields),
|
||||
})
|
||||
}
|
||||
|
||||
if len(scrollResult) >= int(scrollLimit) {
|
||||
return results, &lastID, nil
|
||||
}
|
||||
return results, nil, nil
|
||||
}
|
||||
|
||||
// GetNearest retrieves the nearest points to a vector.
|
||||
func (s *QdrantStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) {
|
||||
filter := buildQdrantFilter(queries)
|
||||
|
||||
searchLimit := uint64(limit)
|
||||
if limit <= 0 {
|
||||
searchLimit = 10
|
||||
}
|
||||
|
||||
searchResult, err := s.client.Query(ctx, &qdrant.QueryPoints{
|
||||
CollectionName: namespace,
|
||||
Query: qdrant.NewQuery(vector...),
|
||||
Filter: filter,
|
||||
Limit: &searchLimit,
|
||||
WithPayload: qdrant.NewWithPayload(true),
|
||||
ScoreThreshold: qdrant.PtrOf(float32(threshold)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search points: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(searchResult))
|
||||
for _, point := range searchResult {
|
||||
score := float64(point.Score)
|
||||
results = append(results, SearchResult{
|
||||
ID: pointIDToString(point.Id),
|
||||
Score: &score,
|
||||
Properties: filterProperties(payloadToMap(point.Payload), selectFields),
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Add stores a new point in the Qdrant vector store.
|
||||
func (s *QdrantStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
pointID, err := parsePointID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid id format (must be UUID): %w", err)
|
||||
}
|
||||
|
||||
point := &qdrant.PointStruct{
|
||||
Id: pointID,
|
||||
Payload: mapToPayload(metadata),
|
||||
}
|
||||
if len(embedding) > 0 {
|
||||
point.Vectors = qdrant.NewVectors(embedding...)
|
||||
}
|
||||
|
||||
_, err = s.client.Upsert(ctx, &qdrant.UpsertPoints{
|
||||
CollectionName: namespace,
|
||||
Points: []*qdrant.PointStruct{point},
|
||||
Wait: qdrant.PtrOf(true),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert point: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a point from the Qdrant vector store.
|
||||
func (s *QdrantStore) Delete(ctx context.Context, namespace string, id string) error {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
pointID, err := parsePointID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid id format: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.Delete(ctx, &qdrant.DeletePoints{
|
||||
CollectionName: namespace,
|
||||
Points: qdrant.NewPointsSelector(pointID),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteAll removes multiple points matching the filter.
|
||||
func (s *QdrantStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) {
|
||||
filter := buildQdrantFilter(queries)
|
||||
|
||||
scrollResult, err := s.client.Scroll(ctx, &qdrant.ScrollPoints{
|
||||
CollectionName: namespace,
|
||||
Filter: filter,
|
||||
WithPayload: qdrant.NewWithPayload(false),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scroll points: %w", err)
|
||||
}
|
||||
|
||||
if len(scrollResult) == 0 {
|
||||
return []DeleteResult{}, nil
|
||||
}
|
||||
|
||||
results := make([]DeleteResult, len(scrollResult))
|
||||
for i, point := range scrollResult {
|
||||
results[i] = DeleteResult{
|
||||
ID: pointIDToString(point.Id),
|
||||
Status: DeleteStatusSuccess,
|
||||
}
|
||||
}
|
||||
|
||||
var deleteErr error
|
||||
if filter != nil {
|
||||
_, deleteErr = s.client.Delete(ctx, &qdrant.DeletePoints{
|
||||
CollectionName: namespace,
|
||||
Points: qdrant.NewPointsSelectorFilter(filter),
|
||||
})
|
||||
} else {
|
||||
pointIDs := make([]*qdrant.PointId, len(scrollResult))
|
||||
for i, point := range scrollResult {
|
||||
pointIDs[i] = point.Id
|
||||
}
|
||||
_, deleteErr = s.client.Delete(ctx, &qdrant.DeletePoints{
|
||||
CollectionName: namespace,
|
||||
Points: qdrant.NewPointsSelectorIDs(pointIDs),
|
||||
})
|
||||
}
|
||||
|
||||
if deleteErr != nil {
|
||||
for i := range results {
|
||||
results[i].Status = DeleteStatusError
|
||||
results[i].Error = deleteErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Close closes the Qdrant client connection.
|
||||
func (s *QdrantStore) Close(ctx context.Context, namespace string) error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// RequiresVectors returns true because Qdrant is a dedicated vector database
|
||||
// that requires vectors for all points/entries.
|
||||
func (s *QdrantStore) RequiresVectors() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// newQdrantStore creates a new Qdrant vector store.
|
||||
func newQdrantStore(ctx context.Context, config *QdrantConfig, logger schemas.Logger) (*QdrantStore, error) {
|
||||
if strings.TrimSpace(config.Host.GetValue()) == "" {
|
||||
return nil, fmt.Errorf("qdrant host is required")
|
||||
}
|
||||
client, err := qdrant.NewClient(&qdrant.Config{
|
||||
Host: config.Host.GetValue(),
|
||||
Port: config.Port.CoerceInt(6334),
|
||||
APIKey: config.APIKey.GetValue(),
|
||||
UseTLS: config.UseTLS.CoerceBool(false),
|
||||
SkipCompatibilityCheck: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create qdrant client: %w", err)
|
||||
}
|
||||
|
||||
_, err = client.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to qdrant: %w", err)
|
||||
}
|
||||
|
||||
return &QdrantStore{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parsePointID(id string) (*qdrant.PointId, error) {
|
||||
if _, err := uuid.Parse(id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return qdrant.NewID(id), nil
|
||||
}
|
||||
|
||||
func pointIDToString(id *qdrant.PointId) string {
|
||||
if id == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := id.PointIdOptions.(type) {
|
||||
case *qdrant.PointId_Uuid:
|
||||
return v.Uuid
|
||||
case *qdrant.PointId_Num:
|
||||
return fmt.Sprintf("%d", v.Num)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func payloadToMap(payload map[string]*qdrant.Value) map[string]interface{} {
|
||||
if payload == nil {
|
||||
return make(map[string]interface{})
|
||||
}
|
||||
|
||||
result := make(map[string]interface{}, len(payload))
|
||||
for k, v := range payload {
|
||||
result[k] = valueToInterface(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func valueToInterface(v *qdrant.Value) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
switch val := v.Kind.(type) {
|
||||
case *qdrant.Value_StringValue:
|
||||
return val.StringValue
|
||||
case *qdrant.Value_IntegerValue:
|
||||
return val.IntegerValue
|
||||
case *qdrant.Value_DoubleValue:
|
||||
return val.DoubleValue
|
||||
case *qdrant.Value_BoolValue:
|
||||
return val.BoolValue
|
||||
case *qdrant.Value_ListValue:
|
||||
list := make([]interface{}, len(val.ListValue.Values))
|
||||
for i, item := range val.ListValue.Values {
|
||||
list[i] = valueToInterface(item)
|
||||
}
|
||||
return list
|
||||
case *qdrant.Value_StructValue:
|
||||
return payloadToMap(val.StructValue.Fields)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func mapToPayload(m map[string]interface{}) map[string]*qdrant.Value {
|
||||
if m == nil {
|
||||
return make(map[string]*qdrant.Value)
|
||||
}
|
||||
// Convert []string to []interface{} since Qdrant's NewValueMap doesn't handle []string directly
|
||||
converted := make(map[string]interface{}, len(m))
|
||||
for k, v := range m {
|
||||
switch val := v.(type) {
|
||||
case []string:
|
||||
// Convert []string to []interface{}
|
||||
interfaceSlice := make([]interface{}, len(val))
|
||||
for i, s := range val {
|
||||
interfaceSlice[i] = s
|
||||
}
|
||||
converted[k] = interfaceSlice
|
||||
default:
|
||||
converted[k] = v
|
||||
}
|
||||
}
|
||||
return qdrant.NewValueMap(converted)
|
||||
}
|
||||
|
||||
// filterProperties returns the subset of props named in selectFields. With no
// select fields the input map itself is returned (not a copy). Requested
// fields that are absent from props are omitted from the result.
func filterProperties(props map[string]interface{}, selectFields []string) map[string]interface{} {
	wanted := len(selectFields)
	if wanted == 0 {
		return props
	}

	subset := make(map[string]interface{}, wanted)
	for idx := 0; idx < wanted; idx++ {
		name := selectFields[idx]
		value, present := props[name]
		if !present {
			continue
		}
		subset[name] = value
	}
	return subset
}
|
||||
|
||||
func buildQdrantFilter(queries []Query) *qdrant.Filter {
|
||||
if len(queries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var conditions []*qdrant.Condition
|
||||
for _, q := range queries {
|
||||
condition := buildQdrantCondition(q)
|
||||
if condition != nil {
|
||||
conditions = append(conditions, condition)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &qdrant.Filter{
|
||||
Must: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
func buildQdrantCondition(q Query) *qdrant.Condition {
|
||||
field := q.Field
|
||||
|
||||
switch q.Operator {
|
||||
case QueryOperatorEqual:
|
||||
return buildMatchCondition(field, q.Value)
|
||||
case QueryOperatorNotEqual:
|
||||
matchCond := buildMatchCondition(field, q.Value)
|
||||
if matchCond == nil {
|
||||
return nil
|
||||
}
|
||||
return qdrant.NewFilterAsCondition(&qdrant.Filter{
|
||||
MustNot: []*qdrant.Condition{matchCond},
|
||||
})
|
||||
case QueryOperatorGreaterThan:
|
||||
return buildRangeCondition(field, q.Value, "gt")
|
||||
case QueryOperatorGreaterThanOrEqual:
|
||||
return buildRangeCondition(field, q.Value, "gte")
|
||||
case QueryOperatorLessThan:
|
||||
return buildRangeCondition(field, q.Value, "lt")
|
||||
case QueryOperatorLessThanOrEqual:
|
||||
return buildRangeCondition(field, q.Value, "lte")
|
||||
case QueryOperatorIsNull:
|
||||
return qdrant.NewIsNull(field)
|
||||
case QueryOperatorIsNotNull:
|
||||
return qdrant.NewFilterAsCondition(&qdrant.Filter{
|
||||
MustNot: []*qdrant.Condition{qdrant.NewIsNull(field)},
|
||||
})
|
||||
case QueryOperatorContainsAny:
|
||||
switch v := q.Value.(type) {
|
||||
case []string:
|
||||
return qdrant.NewMatchKeywords(field, v...)
|
||||
case []int:
|
||||
int64s := make([]int64, len(v))
|
||||
for i, val := range v {
|
||||
int64s[i] = int64(val)
|
||||
}
|
||||
return qdrant.NewMatchInts(field, int64s...)
|
||||
case []int64:
|
||||
return qdrant.NewMatchInts(field, v...)
|
||||
}
|
||||
return buildMatchCondition(field, q.Value)
|
||||
case QueryOperatorContainsAll:
|
||||
if values, ok := q.Value.([]interface{}); ok {
|
||||
var mustConditions []*qdrant.Condition
|
||||
for _, v := range values {
|
||||
cond := buildMatchCondition(field, v)
|
||||
if cond != nil {
|
||||
mustConditions = append(mustConditions, cond)
|
||||
}
|
||||
}
|
||||
if len(mustConditions) > 0 {
|
||||
return qdrant.NewFilterAsCondition(&qdrant.Filter{
|
||||
Must: mustConditions,
|
||||
})
|
||||
}
|
||||
}
|
||||
return buildMatchCondition(field, q.Value)
|
||||
case QueryOperatorLike:
|
||||
if str, ok := q.Value.(string); ok {
|
||||
return qdrant.NewMatchText(field, str)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return buildMatchCondition(field, q.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func buildMatchCondition(field string, value interface{}) *qdrant.Condition {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return qdrant.NewMatchKeyword(field, v)
|
||||
case int:
|
||||
return qdrant.NewMatchInt(field, int64(v))
|
||||
case int32:
|
||||
return qdrant.NewMatchInt(field, int64(v))
|
||||
case int64:
|
||||
return qdrant.NewMatchInt(field, v)
|
||||
case bool:
|
||||
return qdrant.NewMatchBool(field, v)
|
||||
default:
|
||||
return qdrant.NewMatchKeyword(field, fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
|
||||
func buildRangeCondition(field string, value interface{}, op string) *qdrant.Condition {
|
||||
var floatVal float64
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
floatVal = float64(v)
|
||||
case int32:
|
||||
floatVal = float64(v)
|
||||
case int64:
|
||||
floatVal = float64(v)
|
||||
case float32:
|
||||
floatVal = float64(v)
|
||||
case float64:
|
||||
floatVal = v
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
r := &qdrant.Range{}
|
||||
switch op {
|
||||
case "gt":
|
||||
r.Gt = qdrant.PtrOf(floatVal)
|
||||
case "gte":
|
||||
r.Gte = qdrant.PtrOf(floatVal)
|
||||
case "lt":
|
||||
r.Lt = qdrant.PtrOf(floatVal)
|
||||
case "lte":
|
||||
r.Lte = qdrant.PtrOf(floatVal)
|
||||
}
|
||||
return qdrant.NewRange(field, r)
|
||||
}
|
||||
506
framework/vectorstore/qdrant_test.go
Normal file
506
framework/vectorstore/qdrant_test.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Defaults shared by the Qdrant tests.
const (
	// QdrantTestTimeout bounds the context created for each test setup.
	QdrantTestTimeout = 30 * time.Second
	// QdrantTestCollection is the collection the tests read and write.
	QdrantTestCollection = "bifrost-test-collection"
	// QdrantTestDefaultHost is used when QDRANT_HOST is not set.
	QdrantTestDefaultHost = "localhost"
	// QdrantTestDefaultPort is used when QDRANT_PORT is not set.
	QdrantTestDefaultPort = "6334"
	// QdrantTestDimension is the vector size of the test collection.
	QdrantTestDimension = 384
)
|
||||
|
||||
type QdrantTestSetup struct {
|
||||
Store *QdrantStore
|
||||
Logger schemas.Logger
|
||||
Config QdrantConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewQdrantTestSetup(t *testing.T) *QdrantTestSetup {
|
||||
host := schemas.NewEnvVar(getEnvWithDefault("QDRANT_HOST", QdrantTestDefaultHost))
|
||||
port := schemas.NewEnvVar(getEnvWithDefault("QDRANT_PORT", QdrantTestDefaultPort))
|
||||
apiKey := schemas.NewEnvVar(os.Getenv("QDRANT_API_KEY"))
|
||||
useTLS := schemas.NewEnvVar(os.Getenv("QDRANT_USE_TLS"))
|
||||
|
||||
config := QdrantConfig{
|
||||
Host: *host,
|
||||
Port: *port,
|
||||
APIKey: *apiKey,
|
||||
UseTLS: *useTLS,
|
||||
}
|
||||
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), QdrantTestTimeout)
|
||||
|
||||
store, err := newQdrantStore(ctx, &config, logger)
|
||||
if err != nil {
|
||||
cancel()
|
||||
t.Fatalf("Failed to create Qdrant store: %v", err)
|
||||
}
|
||||
|
||||
setup := &QdrantTestSetup{
|
||||
Store: store,
|
||||
Logger: logger,
|
||||
Config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
setup.ensureCollectionExists(t)
|
||||
|
||||
return setup
|
||||
}
|
||||
|
||||
func (ts *QdrantTestSetup) Cleanup(t *testing.T) {
|
||||
defer ts.cancel()
|
||||
|
||||
if !testing.Short() {
|
||||
ts.cleanupTestData(t)
|
||||
}
|
||||
|
||||
if err := ts.Store.Close(ts.ctx, QdrantTestCollection); err != nil {
|
||||
t.Logf("Warning: Failed to close store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *QdrantTestSetup) ensureCollectionExists(t *testing.T) {
|
||||
properties := map[string]VectorStoreProperties{
|
||||
"key": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"type": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"test_type": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"size": {
|
||||
DataType: VectorStorePropertyTypeInteger,
|
||||
},
|
||||
"public": {
|
||||
DataType: VectorStorePropertyTypeBoolean,
|
||||
},
|
||||
"author": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"request_hash": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"user": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"lang": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"category": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"content": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
"response": {
|
||||
DataType: VectorStorePropertyTypeString,
|
||||
},
|
||||
}
|
||||
|
||||
err := ts.Store.CreateNamespace(ts.ctx, QdrantTestCollection, QdrantTestDimension, properties)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create collection %q: %v", QdrantTestCollection, err)
|
||||
}
|
||||
t.Logf("Created test collection: %s", QdrantTestCollection)
|
||||
}
|
||||
|
||||
func (ts *QdrantTestSetup) cleanupTestData(t *testing.T) {
|
||||
allTestKeys, _, err := ts.Store.GetAll(ts.ctx, QdrantTestCollection, []Query{}, []string{}, nil, 1000)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to get all test keys: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range allTestKeys {
|
||||
err := ts.Store.Delete(ts.ctx, QdrantTestCollection, key.ID)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete test key %s: %v", key.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Cleaned up test collection: %s", QdrantTestCollection)
|
||||
}
|
||||
|
||||
func TestQdrantConfig_Validation(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config QdrantConfig
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: QdrantConfig{
|
||||
Host: *schemas.NewEnvVar("localhost"),
|
||||
Port: *schemas.NewEnvVar("6334"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
config: QdrantConfig{
|
||||
Port: *schemas.NewEnvVar("6334"),
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "qdrant host is required",
|
||||
},
|
||||
{
|
||||
name: "missing port uses default",
|
||||
config: QdrantConfig{
|
||||
Host: *schemas.NewEnvVar("localhost"),
|
||||
},
|
||||
expectError: false, // Port defaults to 6334 via CoerceInt fallback
|
||||
},
|
||||
{
|
||||
name: "with api key",
|
||||
config: QdrantConfig{
|
||||
Host: *schemas.NewEnvVar("cluster.qdrant.io"),
|
||||
Port: *schemas.NewEnvVar("6334"),
|
||||
APIKey: *schemas.NewEnvVar("test-key"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "with tls",
|
||||
config: QdrantConfig{
|
||||
Host: *schemas.NewEnvVar("localhost"),
|
||||
Port: *schemas.NewEnvVar("6334"),
|
||||
UseTLS: *schemas.NewEnvVar("true"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store, err := newQdrantStore(ctx, &tt.config, logger)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, store)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "failed to connect")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePointID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid UUID",
|
||||
id: "550e8400-e29b-41d4-a716-446655440000",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid UUID",
|
||||
id: "not-a-uuid",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
id: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "numeric string",
|
||||
id: "12345",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pointID, err := parsePointID(tt.id)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, pointID)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, pointID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildQdrantFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queries []Query
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty queries",
|
||||
queries: []Query{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "single string query",
|
||||
queries: []Query{
|
||||
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "single numeric query",
|
||||
queries: []Query{
|
||||
{Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple queries (AND)",
|
||||
queries: []Query{
|
||||
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
|
||||
{Field: "public", Operator: QueryOperatorEqual, Value: true},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "null checks",
|
||||
queries: []Query{
|
||||
{Field: "author", Operator: QueryOperatorIsNull, Value: nil},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not null checks",
|
||||
queries: []Query{
|
||||
{Field: "author", Operator: QueryOperatorIsNotNull, Value: nil},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildQdrantFilter(tt.queries)
|
||||
|
||||
if tt.expected {
|
||||
assert.NotNil(t, result)
|
||||
} else {
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQdrantStore_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewQdrantTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
err := setup.Store.Ping(setup.ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
key := generateUUID()
|
||||
err = setup.Store.Add(setup.ctx, QdrantTestCollection, key, generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "document"})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := setup.Store.GetChunk(setup.ctx, QdrantTestCollection, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "document", result.Properties["type"])
|
||||
|
||||
keys := []string{generateUUID(), generateUUID()}
|
||||
for i, k := range keys {
|
||||
err = setup.Store.Add(setup.ctx, QdrantTestCollection, k, generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": i})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
results, err := setup.Store.GetChunks(setup.ctx, QdrantTestCollection, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
}
|
||||
|
||||
func TestQdrantStore_Filtering(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewQdrantTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
metadata := map[string]interface{}{"type": "pdf", "public": true}
|
||||
if i == 1 {
|
||||
metadata["type"] = "docx"
|
||||
metadata["public"] = false
|
||||
}
|
||||
err := setup.Store.Add(setup.ctx, QdrantTestCollection, generateUUID(), generateTestEmbedding(QdrantTestDimension), metadata)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
queries := []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "pdf"}}
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, QdrantTestCollection, queries, []string{"type"}, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
multiQuery := []Query{
|
||||
{Field: "type", Operator: QueryOperatorEqual, Value: "pdf"},
|
||||
{Field: "public", Operator: QueryOperatorEqual, Value: true},
|
||||
}
|
||||
results, _, err = setup.Store.GetAll(setup.ctx, QdrantTestCollection, multiQuery, []string{"type"}, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
}
|
||||
|
||||
func TestQdrantStore_VectorSearch(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewQdrantTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
emb := generateTestEmbedding(QdrantTestDimension)
|
||||
err := setup.Store.Add(setup.ctx, QdrantTestCollection, generateUUID(), emb, map[string]interface{}{"type": "tech"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = setup.Store.Add(setup.ctx, QdrantTestCollection, generateUUID(), generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "sports"})
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := setup.Store.GetNearest(setup.ctx, QdrantTestCollection, emb, nil, []string{"type"}, 0.1, 10)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(results), 1)
|
||||
require.NotNil(t, results[0].Score)
|
||||
|
||||
queries := []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "tech"}}
|
||||
results, err = setup.Store.GetNearest(setup.ctx, QdrantTestCollection, emb, queries, []string{"type"}, 0.1, 10)
|
||||
require.NoError(t, err)
|
||||
for _, result := range results {
|
||||
assert.Equal(t, "tech", result.Properties["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestQdrantStore_InterfaceCompliance(t *testing.T) {
|
||||
var _ VectorStore = (*QdrantStore)(nil)
|
||||
}
|
||||
|
||||
func TestVectorStoreFactory_Qdrant(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
|
||||
host := schemas.NewEnvVar(getEnvWithDefault("QDRANT_HOST", QdrantTestDefaultHost))
|
||||
port := schemas.NewEnvVar(getEnvWithDefault("QDRANT_PORT", QdrantTestDefaultPort))
|
||||
apiKey := schemas.NewEnvVar(os.Getenv("QDRANT_API_KEY"))
|
||||
|
||||
config := &Config{
|
||||
Enabled: true,
|
||||
Type: VectorStoreTypeQdrant,
|
||||
Config: QdrantConfig{
|
||||
Host: *host,
|
||||
Port: *port,
|
||||
APIKey: *apiKey,
|
||||
},
|
||||
}
|
||||
|
||||
store, err := NewVectorStore(context.Background(), config, logger)
|
||||
if err != nil {
|
||||
t.Skipf("Could not create Qdrant store: %v", err)
|
||||
}
|
||||
defer store.Close(context.Background(), QdrantTestCollection)
|
||||
|
||||
qdrantStore, ok := store.(*QdrantStore)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, qdrantStore)
|
||||
}
|
||||
|
||||
func TestQdrantStore_DimensionHandling(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewQdrantTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
testCollection := "TestDim"
|
||||
props := map[string]VectorStoreProperties{"type": {DataType: VectorStorePropertyTypeString}}
|
||||
|
||||
err := setup.Store.CreateNamespace(setup.ctx, testCollection, 512, props)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = setup.Store.Add(setup.ctx, testCollection, generateUUID(), generateTestEmbedding(512), map[string]interface{}{"type": "test"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = setup.Store.DeleteNamespace(setup.ctx, testCollection)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = setup.Store.CreateNamespace(setup.ctx, testCollection, QdrantTestDimension, props)
|
||||
require.NoError(t, err)
|
||||
|
||||
emb := generateTestEmbedding(QdrantTestDimension)
|
||||
err = setup.Store.Add(setup.ctx, testCollection, generateUUID(), emb, map[string]interface{}{"type": "test"})
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := setup.Store.GetNearest(setup.ctx, testCollection, emb, nil, []string{"type"}, 0.8, 10)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(results), 1)
|
||||
|
||||
setup.Store.DeleteNamespace(setup.ctx, testCollection)
|
||||
}
|
||||
|
||||
func TestQdrantStore_ErrorHandling(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewQdrantTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
_, err := setup.Store.GetChunk(setup.ctx, QdrantTestCollection, generateUUID())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
|
||||
err = setup.Store.Add(setup.ctx, QdrantTestCollection, "", generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "test"})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "id is required")
|
||||
|
||||
err = setup.Store.Add(setup.ctx, QdrantTestCollection, "not-a-uuid", generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "test"})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid id format")
|
||||
|
||||
err = setup.Store.Delete(setup.ctx, QdrantTestCollection, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "id is required")
|
||||
|
||||
err = setup.Store.Delete(setup.ctx, QdrantTestCollection, "not-a-uuid")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid id format")
|
||||
}
|
||||
1753
framework/vectorstore/redis.go
Normal file
1753
framework/vectorstore/redis.go
Normal file
File diff suppressed because it is too large
Load Diff
1817
framework/vectorstore/redis_test.go
Normal file
1817
framework/vectorstore/redis_test.go
Normal file
File diff suppressed because it is too large
Load Diff
238
framework/vectorstore/store.go
Normal file
238
framework/vectorstore/store.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// Package vectorstore provides a generic interface for vector stores.
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// VectorStoreType names a vector store backend; it is the value of the
// "type" field in Config.
type VectorStoreType string

// Backends known to this package.
const (
	VectorStoreTypeWeaviate VectorStoreType = "weaviate"
	VectorStoreTypeRedis    VectorStoreType = "redis"
	VectorStoreTypeQdrant   VectorStoreType = "qdrant"
	VectorStoreTypePinecone VectorStoreType = "pinecone"
)
|
||||
|
||||
// Query represents a query to the vector store.
|
||||
type Query struct {
|
||||
Field string
|
||||
Operator QueryOperator
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// QueryOperator identifies the comparison performed by a Query.
type QueryOperator string

// Supported query operators.
const (
	QueryOperatorEqual              QueryOperator = "Equal"
	QueryOperatorNotEqual           QueryOperator = "NotEqual"
	QueryOperatorGreaterThan        QueryOperator = "GreaterThan"
	QueryOperatorLessThan           QueryOperator = "LessThan"
	QueryOperatorGreaterThanOrEqual QueryOperator = "GreaterThanOrEqual"
	QueryOperatorLessThanOrEqual    QueryOperator = "LessThanOrEqual"
	QueryOperatorLike               QueryOperator = "Like"
	QueryOperatorContainsAny        QueryOperator = "ContainsAny"
	QueryOperatorContainsAll        QueryOperator = "ContainsAll"
	QueryOperatorIsNull             QueryOperator = "IsNull"
	QueryOperatorIsNotNull          QueryOperator = "IsNotNull"
)
|
||||
|
||||
// SearchResult is one entry returned by a vector store lookup or search.
type SearchResult struct {
	ID         string                 // identifier of the entry
	Score      *float64               // optional score; nil when the operation produced none
	Properties map[string]interface{} // metadata stored alongside the entry
}
|
||||
|
||||
// DeleteResult represents the result of a delete operation.
|
||||
type DeleteResult struct {
|
||||
ID string
|
||||
Status DeleteStatus
|
||||
Error string
|
||||
}
|
||||
|
||||
// DeleteStatus is the outcome recorded in a DeleteResult.
type DeleteStatus string

// Possible delete outcomes.
const (
	DeleteStatusSuccess DeleteStatus = "success"
	DeleteStatusError   DeleteStatus = "error"
)
|
||||
|
||||
type VectorStoreProperties struct {
|
||||
DataType VectorStorePropertyType `json:"data_type"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// VectorStorePropertyType is the data type of a namespace property.
type VectorStorePropertyType string

// Supported property data types.
const (
	VectorStorePropertyTypeString      VectorStorePropertyType = "string"
	VectorStorePropertyTypeInteger     VectorStorePropertyType = "integer"
	VectorStorePropertyTypeBoolean     VectorStorePropertyType = "boolean"
	VectorStorePropertyTypeStringArray VectorStorePropertyType = "string[]"
)
|
||||
|
||||
// disableScanFallbackContextKey is the unexported context key under which the
// "scan fallback disabled" flag is stored.
type disableScanFallbackContextKey struct{}
|
||||
|
||||
// VectorStore represents the interface for the vector store.
|
||||
type VectorStore interface {
|
||||
// Health check
|
||||
Ping(ctx context.Context) error
|
||||
// CreateNamespace creates a new namespace in the vector store.
|
||||
CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error
|
||||
// DeleteNamespace deletes a namespace from the vector store.
|
||||
DeleteNamespace(ctx context.Context, namespace string) error
|
||||
// GetChunk retrieves a single vector from the vector store.
|
||||
GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error)
|
||||
// GetChunks retrieves multiple vectors from the vector store.
|
||||
GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error)
|
||||
// GetAll retrieves all vectors from the vector store.
|
||||
GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error)
|
||||
// GetNearest retrieves the nearest vectors from the vector store.
|
||||
GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error)
|
||||
// RequiresVectors returns true if the vector store requires vectors for all entries.
|
||||
// Dedicated vector databases like Qdrant and Pinecone require vectors, while
|
||||
// more flexible stores like Weaviate and Redis can store metadata-only entries.
|
||||
RequiresVectors() bool
|
||||
// Add stores a new vector in the vector store.
|
||||
Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error
|
||||
// Delete removes a vector from the vector store.
|
||||
Delete(ctx context.Context, namespace string, id string) error
|
||||
// DeleteAll deletes all vectors from the vector store.
|
||||
DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error)
|
||||
// Close closes the vector store.
|
||||
Close(ctx context.Context, namespace string) error
|
||||
}
|
||||
|
||||
// WithDisableScanFallback returns a derived context that tells vector stores not
|
||||
// to fall back to full scans when indexed search fails.
|
||||
func WithDisableScanFallback(ctx context.Context) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, disableScanFallbackContextKey{}, true)
|
||||
}
|
||||
|
||||
// IsScanFallbackDisabled reports whether scan fallback has been disabled for
|
||||
// the current vector store operation.
|
||||
func IsScanFallbackDisabled(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return false
|
||||
}
|
||||
disabled, _ := ctx.Value(disableScanFallbackContextKey{}).(bool)
|
||||
return disabled
|
||||
}
|
||||
|
||||
// Config represents the configuration for the vector store.
|
||||
type Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type VectorStoreType `json:"type"`
|
||||
Config any `json:"config"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the config from JSON.
|
||||
func (c *Config) UnmarshalJSON(data []byte) error {
|
||||
// First, unmarshal into a temporary struct to get the basic fields
|
||||
type TempConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type string `json:"type"`
|
||||
Config json.RawMessage `json:"config"` // Keep as raw JSON
|
||||
}
|
||||
|
||||
var temp TempConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Set basic fields
|
||||
c.Enabled = temp.Enabled
|
||||
c.Type = VectorStoreType(temp.Type)
|
||||
|
||||
// Parse the config field based on type
|
||||
switch c.Type {
|
||||
case VectorStoreTypeWeaviate:
|
||||
var weaviateConfig WeaviateConfig
|
||||
if err := json.Unmarshal(temp.Config, &weaviateConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal weaviate config: %w", err)
|
||||
}
|
||||
c.Config = weaviateConfig
|
||||
case VectorStoreTypeRedis:
|
||||
var redisConfig RedisConfig
|
||||
if err := json.Unmarshal(temp.Config, &redisConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal redis config: %w", err)
|
||||
}
|
||||
// Process env. values for sensitive fields
|
||||
c.Config = redisConfig
|
||||
case VectorStoreTypeQdrant:
|
||||
var qdrantConfig QdrantConfig
|
||||
if err := json.Unmarshal(temp.Config, &qdrantConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal qdrant config: %w", err)
|
||||
}
|
||||
c.Config = qdrantConfig
|
||||
case VectorStoreTypePinecone:
|
||||
var pineconeConfig PineconeConfig
|
||||
if err := json.Unmarshal(temp.Config, &pineconeConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal pinecone config: %w", err)
|
||||
}
|
||||
c.Config = pineconeConfig
|
||||
default:
|
||||
return fmt.Errorf("unknown vector store type: %s", temp.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewVectorStore returns a new vector store based on the configuration.
|
||||
func NewVectorStore(ctx context.Context, config *Config, logger schemas.Logger) (VectorStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config cannot be nil")
|
||||
}
|
||||
|
||||
if !config.Enabled {
|
||||
return nil, fmt.Errorf("vector store is disabled")
|
||||
}
|
||||
|
||||
switch config.Type {
|
||||
case VectorStoreTypeWeaviate:
|
||||
if config.Config == nil {
|
||||
return nil, fmt.Errorf("weaviate config is required")
|
||||
}
|
||||
weaviateConfig, ok := config.Config.(WeaviateConfig)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid weaviate config")
|
||||
}
|
||||
return newWeaviateStore(ctx, &weaviateConfig, logger)
|
||||
case VectorStoreTypeRedis:
|
||||
if config.Config == nil {
|
||||
return nil, fmt.Errorf("redis config is required")
|
||||
}
|
||||
redisConfig, ok := config.Config.(RedisConfig)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid redis config")
|
||||
}
|
||||
return newRedisStore(ctx, redisConfig, logger)
|
||||
case VectorStoreTypeQdrant:
|
||||
if config.Config == nil {
|
||||
return nil, fmt.Errorf("qdrant config is required")
|
||||
}
|
||||
qdrantConfig, ok := config.Config.(QdrantConfig)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid qdrant config")
|
||||
}
|
||||
return newQdrantStore(ctx, &qdrantConfig, logger)
|
||||
case VectorStoreTypePinecone:
|
||||
if config.Config == nil {
|
||||
return nil, fmt.Errorf("pinecone config is required")
|
||||
}
|
||||
pineconeConfig, ok := config.Config.(PineconeConfig)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid pinecone config")
|
||||
}
|
||||
return newPineconeStore(ctx, &pineconeConfig, logger)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid vector store type: %s", config.Type)
|
||||
}
|
||||
46
framework/vectorstore/test_utils.go
Normal file
46
framework/vectorstore/test_utils.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Helper functions
|
||||
// getEnvWithDefault returns the value of the environment variable key, or
// defaultValue when the variable is unset or empty.
func getEnvWithDefault(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
|
||||
|
||||
// getEnvWithDefaultInt returns the environment variable key parsed as an int.
// When the variable is unset or empty it returns defaultValue; when the value
// is not a valid integer it returns strconv.Atoi's result and error.
func getEnvWithDefaultInt(key string, defaultValue int) (int, error) {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultValue, nil
	}
	return strconv.Atoi(raw)
}
|
||||
|
||||
func generateUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// generateTestEmbedding returns a slice of dim pseudo-random float32 values,
// each computed as rand.Float32()*2 - 1.
func generateTestEmbedding(dim int) []float32 {
	vec := make([]float32, dim)
	for idx := 0; idx < len(vec); idx++ {
		vec[idx] = rand.Float32()*2 - 1 // Random values between -1 and 1
	}
	return vec
}
|
||||
|
||||
// generateSimilarEmbedding returns a new slice equal to original with random
// noise added to every component. The noise is drawn from [-1, 1) and scaled
// by (1 - similarity) * 0.1, so a similarity of 1 reproduces original.
func generateSimilarEmbedding(original []float32, similarity float32) []float32 {
	perturbed := make([]float32, len(original))
	for idx, base := range original {
		// Add small random noise to create similar but not identical embedding
		jitter := (rand.Float32()*2 - 1) * (1 - similarity) * 0.1
		perturbed[idx] = base + jitter
	}
	return perturbed
}
|
||||
15
framework/vectorstore/utils.go
Normal file
15
framework/vectorstore/utils.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// withTimeout returns ctx wrapped with the given timeout when timeout is
// positive. Otherwise it returns ctx unchanged together with a no-op cancel
// function, so call sites can always defer cancel().
func withTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	if timeout <= 0 {
		// No-op cancel to simplify call sites.
		return ctx, func() {}
	}
	return context.WithTimeout(ctx, timeout)
}
|
||||
637
framework/vectorstore/weaviate.go
Normal file
637
framework/vectorstore/weaviate.go
Normal file
@@ -0,0 +1,637 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
|
||||
"github.com/weaviate/weaviate/entities/models"
|
||||
)
|
||||
|
||||
// Default values for the Weaviate vector store.
const (
	// DefaultClassName is the default class name (Weaviate prefers
	// PascalCase).
	DefaultClassName = "BifrostStore"
)
|
||||
|
||||
// WeaviateConfig represents the configuration for the Weaviate vector store.
|
||||
type WeaviateConfig struct {
|
||||
// Connection settings
|
||||
Scheme string `json:"scheme"` // "http" or "https" - REQUIRED
|
||||
Host *schemas.EnvVar `json:"host"` // "localhost:8080" - REQUIRED
|
||||
GrpcConfig *WeaviateGrpcConfig `json:"grpc_config,omitempty"` // grpc config for weaviate (optional)
|
||||
|
||||
// Authentication settings (optional)
|
||||
APIKey *schemas.EnvVar `json:"api_key,omitempty"` // API key for authentication
|
||||
Headers map[string]string `json:"headers,omitempty"` // Additional headers
|
||||
|
||||
// Connection settings
|
||||
Timeout time.Duration `json:"timeout,omitempty"` // Request timeout (optional)
|
||||
}
|
||||
|
||||
type WeaviateGrpcConfig struct {
|
||||
// Host is the host of the weaviate server (host:port).
|
||||
// If host is without a port number then the 80 port for insecured and 443 port for secured connections will be used.
|
||||
Host *schemas.EnvVar `json:"host"`
|
||||
// Secured is a boolean flag indicating if the connection is secured
|
||||
Secured bool `json:"secured"`
|
||||
}
|
||||
|
||||
// WeaviateStore represents the Weaviate vector store.
|
||||
type WeaviateStore struct {
|
||||
client *weaviate.Client
|
||||
config *WeaviateConfig
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// Ping checks if the Weaviate server is reachable.
|
||||
func (s *WeaviateStore) Ping(ctx context.Context) error {
|
||||
_, err := s.client.Misc().MetaGetter().Do(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Add stores a new object (with or without embedding)
|
||||
func (s *WeaviateStore) Add(ctx context.Context, className string, id string, embedding []float32, metadata map[string]interface{}) error {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
obj := &models.Object{
|
||||
Class: className,
|
||||
Properties: metadata,
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(embedding) > 0 {
|
||||
_, err = s.client.Data().Creator().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
WithProperties(obj.Properties).
|
||||
WithVector(embedding).
|
||||
Do(ctx)
|
||||
} else {
|
||||
_, err = s.client.Data().Creator().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
WithProperties(obj.Properties).
|
||||
Do(ctx)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetChunk returns the "metadata" for a single key
|
||||
func (s *WeaviateStore) GetChunk(ctx context.Context, className string, id string) (SearchResult, error) {
|
||||
obj, err := s.client.Data().ObjectsGetter().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return SearchResult{}, err
|
||||
}
|
||||
if len(obj) == 0 {
|
||||
return SearchResult{}, fmt.Errorf("not found: %s", id)
|
||||
}
|
||||
|
||||
props, ok := obj[0].Properties.(map[string]interface{})
|
||||
if !ok {
|
||||
return SearchResult{}, fmt.Errorf("invalid properties")
|
||||
}
|
||||
|
||||
return SearchResult{
|
||||
ID: id,
|
||||
Score: nil,
|
||||
Properties: props,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChunks returns multiple objects by ID
|
||||
func (s *WeaviateStore) GetChunks(ctx context.Context, className string, ids []string) ([]SearchResult, error) {
|
||||
out := make([]SearchResult, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
obj, err := s.client.Data().ObjectsGetter().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(obj) > 0 {
|
||||
props, ok := obj[0].Properties.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid properties")
|
||||
}
|
||||
out = append(out, SearchResult{
|
||||
ID: id,
|
||||
Score: nil,
|
||||
Properties: props,
|
||||
})
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetAll with filtering + pagination
|
||||
func (s *WeaviateStore) GetAll(ctx context.Context, className string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
|
||||
where := buildWeaviateFilter(queries)
|
||||
|
||||
fields := []graphql.Field{
|
||||
{Name: "_additional", Fields: []graphql.Field{
|
||||
{Name: "id"},
|
||||
}},
|
||||
}
|
||||
for _, field := range selectFields {
|
||||
fields = append(fields, graphql.Field{Name: field})
|
||||
}
|
||||
|
||||
search := s.client.GraphQL().Get().
|
||||
WithClassName(className).
|
||||
WithLimit(int(limit)).
|
||||
WithFields(fields...)
|
||||
|
||||
if where != nil {
|
||||
search = search.WithWhere(where)
|
||||
}
|
||||
if cursor != nil {
|
||||
search = search.WithAfter(*cursor)
|
||||
}
|
||||
|
||||
resp, err := search.Do(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Check for GraphQL errors
|
||||
if len(resp.Errors) > 0 {
|
||||
var errorMsgs []string
|
||||
for _, err := range resp.Errors {
|
||||
errorMsgs = append(errorMsgs, err.Message)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("graphql errors: %v", errorMsgs)
|
||||
}
|
||||
|
||||
data, ok := resp.Data["Get"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
|
||||
}
|
||||
|
||||
objsRaw, exists := data[className]
|
||||
if !exists {
|
||||
// No results for this class - this is normal, not an error
|
||||
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
objs, ok := objsRaw.([]interface{})
|
||||
if !ok {
|
||||
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(objs))
|
||||
var nextCursor *string
|
||||
for _, o := range objs {
|
||||
obj, ok := o.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert to SearchResult format for consistency
|
||||
searchResult := SearchResult{
|
||||
Properties: obj,
|
||||
}
|
||||
|
||||
if additional, ok := obj["_additional"].(map[string]interface{}); ok {
|
||||
if id, ok := additional["id"].(string); ok {
|
||||
searchResult.ID = id
|
||||
nextCursor = &id
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, searchResult)
|
||||
}
|
||||
|
||||
return results, nextCursor, nil
|
||||
}
|
||||
|
||||
// GetNearest with explicit filters only
|
||||
func (s *WeaviateStore) GetNearest(
|
||||
ctx context.Context,
|
||||
className string,
|
||||
vector []float32,
|
||||
queries []Query,
|
||||
selectFields []string,
|
||||
threshold float64,
|
||||
limit int64,
|
||||
) ([]SearchResult, error) {
|
||||
where := buildWeaviateFilter(queries)
|
||||
|
||||
fields := []graphql.Field{
|
||||
{Name: "_additional", Fields: []graphql.Field{
|
||||
{Name: "id"},
|
||||
{Name: "certainty"},
|
||||
}},
|
||||
}
|
||||
|
||||
for _, field := range selectFields {
|
||||
fields = append(fields, graphql.Field{Name: field})
|
||||
}
|
||||
|
||||
nearVector := s.client.GraphQL().NearVectorArgBuilder().
|
||||
WithVector(vector).
|
||||
WithCertainty(float32(threshold))
|
||||
|
||||
search := s.client.GraphQL().Get().
|
||||
WithClassName(className).
|
||||
WithNearVector(nearVector).
|
||||
WithLimit(int(limit)).
|
||||
WithFields(fields...)
|
||||
|
||||
if where != nil {
|
||||
search = search.WithWhere(where)
|
||||
}
|
||||
|
||||
resp, err := search.Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for GraphQL errors
|
||||
if len(resp.Errors) > 0 {
|
||||
var errorMsgs []string
|
||||
for _, err := range resp.Errors {
|
||||
errorMsgs = append(errorMsgs, err.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("graphql errors: %v", errorMsgs)
|
||||
}
|
||||
|
||||
data, ok := resp.Data["Get"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
|
||||
}
|
||||
|
||||
objsRaw, exists := data[className]
|
||||
if !exists {
|
||||
// No results for this class - this is normal, not an error
|
||||
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
objs, ok := objsRaw.([]interface{})
|
||||
if !ok {
|
||||
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(objs))
|
||||
for _, o := range objs {
|
||||
obj, ok := o.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
additional, ok := obj["_additional"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Safely extract ID
|
||||
idRaw, exists := additional["id"]
|
||||
if !exists || idRaw == nil {
|
||||
continue
|
||||
}
|
||||
id, ok := idRaw.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Safely extract certainty/score with default value
|
||||
var score float64
|
||||
if certaintyRaw, exists := additional["certainty"]; exists && certaintyRaw != nil {
|
||||
switch v := certaintyRaw.(type) {
|
||||
case float64:
|
||||
score = v
|
||||
case float32:
|
||||
score = float64(v)
|
||||
case int:
|
||||
score = float64(v)
|
||||
case int64:
|
||||
score = float64(v)
|
||||
default:
|
||||
score = 0.0 // Default score if type conversion fails
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, SearchResult{
|
||||
ID: id,
|
||||
Score: &score,
|
||||
Properties: obj,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Delete removes multiple objects by ID
|
||||
func (s *WeaviateStore) Delete(ctx context.Context, className string, id string) error {
|
||||
return s.client.Data().Deleter().
|
||||
WithClassName(className).
|
||||
WithID(id).
|
||||
Do(ctx)
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) DeleteAll(ctx context.Context, className string, queries []Query) ([]DeleteResult, error) {
|
||||
// Check if class exists first to avoid 500 errors from Weaviate
|
||||
exists, err := s.client.Schema().ClassExistenceChecker().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check class existence: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return []DeleteResult{}, nil // Class doesn't exist, nothing to delete
|
||||
}
|
||||
|
||||
where := buildWeaviateFilter(queries)
|
||||
|
||||
res, err := s.client.Batch().ObjectsBatchDeleter().
|
||||
WithClassName(className).
|
||||
WithWhere(where).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// NOTE: Weaviate is returning an empty array for Results.Objects, even on successful deletes.
|
||||
results := make([]DeleteResult, 0, len(res.Results.Objects))
|
||||
|
||||
for _, obj := range res.Results.Objects {
|
||||
result := DeleteResult{
|
||||
ID: obj.ID.String(),
|
||||
}
|
||||
|
||||
if obj.Status != nil {
|
||||
switch *obj.Status {
|
||||
case "SUCCESS":
|
||||
result.Status = DeleteStatusSuccess
|
||||
case "FAILED":
|
||||
result.Status = DeleteStatusError
|
||||
|
||||
if obj.Errors != nil {
|
||||
var errorMsgs []string
|
||||
for _, err := range obj.Errors.Error {
|
||||
errorMsgs = append(errorMsgs, err.Message)
|
||||
}
|
||||
|
||||
result.Error = strings.Join(errorMsgs, ", ")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) Close(ctx context.Context, className string) error {
|
||||
// nothing to close
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequiresVectors returns true because Weaviate's HNSW index
|
||||
// requires vectors for proper object indexing and retrieval.
|
||||
func (s *WeaviateStore) RequiresVectors() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// newWeaviateStore creates a new Weaviate vector store.
|
||||
func newWeaviateStore(ctx context.Context, config *WeaviateConfig, logger schemas.Logger) (*WeaviateStore, error) {
|
||||
// Validate required config
|
||||
if config.Scheme == "" || (config.Host == nil || config.Host.GetValue() == "") {
|
||||
return nil, fmt.Errorf("weaviate scheme and host are required")
|
||||
}
|
||||
// Build client configuration
|
||||
cfg := weaviate.Config{
|
||||
Scheme: config.Scheme,
|
||||
Host: config.Host.GetValue(),
|
||||
}
|
||||
|
||||
// Add authentication if provided
|
||||
if config.APIKey != nil && config.APIKey.GetValue() != "" {
|
||||
cfg.AuthConfig = auth.ApiKey{Value: config.APIKey.GetValue()}
|
||||
}
|
||||
|
||||
// Add grpc config if provided
|
||||
if config.GrpcConfig != nil {
|
||||
if config.GrpcConfig.Host == nil || config.GrpcConfig.Host.GetValue() == "" {
|
||||
return nil, fmt.Errorf("weaviate grpc host is required")
|
||||
}
|
||||
cfg.GrpcConfig = &grpc.Config{
|
||||
Host: config.GrpcConfig.Host.GetValue(),
|
||||
Secured: config.GrpcConfig.Secured,
|
||||
}
|
||||
}
|
||||
|
||||
// Add custom headers if provided
|
||||
if len(config.Headers) > 0 {
|
||||
cfg.Headers = config.Headers
|
||||
}
|
||||
|
||||
// Create client
|
||||
client, err := weaviate.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create weaviate client: %w", err)
|
||||
}
|
||||
|
||||
// Test connection with meta endpoint
|
||||
testCtx := ctx
|
||||
if config.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
testCtx, cancel = context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
_, err = client.Misc().MetaGetter().Do(testCtx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to weaviate: %w", err)
|
||||
}
|
||||
|
||||
store := &WeaviateStore{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) CreateNamespace(ctx context.Context, className string, dimension int, properties map[string]VectorStoreProperties) error {
|
||||
// Check if class exists
|
||||
exists, err := s.client.Schema().ClassExistenceChecker().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check class existence: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
return nil // Schema already exists
|
||||
}
|
||||
|
||||
// Create properties
|
||||
weaviateProperties := []*models.Property{}
|
||||
for name, prop := range properties {
|
||||
var dataType []string
|
||||
switch prop.DataType {
|
||||
case VectorStorePropertyTypeString:
|
||||
dataType = []string{"string"}
|
||||
case VectorStorePropertyTypeInteger:
|
||||
dataType = []string{"int"}
|
||||
case VectorStorePropertyTypeBoolean:
|
||||
dataType = []string{"boolean"}
|
||||
case VectorStorePropertyTypeStringArray:
|
||||
dataType = []string{"string[]"}
|
||||
}
|
||||
|
||||
weaviateProperties = append(weaviateProperties, &models.Property{
|
||||
Name: name,
|
||||
DataType: dataType,
|
||||
Description: prop.Description,
|
||||
})
|
||||
}
|
||||
|
||||
// Create class schema with all fields we need
|
||||
classSchema := &models.Class{
|
||||
Class: className,
|
||||
Properties: weaviateProperties,
|
||||
VectorIndexType: "hnsw",
|
||||
Vectorizer: "none", // We provide our own vectors
|
||||
}
|
||||
|
||||
if dimension > 0 {
|
||||
classSchema.VectorIndexConfig = map[string]interface{}{
|
||||
"vectorDimensions": dimension,
|
||||
}
|
||||
}
|
||||
|
||||
err = s.client.Schema().ClassCreator().
|
||||
WithClass(classSchema).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create class schema: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WeaviateStore) DeleteNamespace(ctx context.Context, className string) error {
|
||||
exists, err := s.client.Schema().ClassExistenceChecker().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check class existence: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil // Schema already does not exist
|
||||
} else {
|
||||
return s.client.Schema().ClassDeleter().
|
||||
WithClassName(className).
|
||||
Do(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// buildWeaviateFilter converts []Query → Weaviate WhereFilter
|
||||
func buildWeaviateFilter(queries []Query) *filters.WhereBuilder {
|
||||
if len(queries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var operands []*filters.WhereBuilder
|
||||
for _, q := range queries {
|
||||
// Convert string operator to filters operator
|
||||
operator := convertOperator(q.Operator)
|
||||
|
||||
fieldPath := strings.Split(q.Field, ".")
|
||||
|
||||
whereClause := filters.Where().
|
||||
WithPath(fieldPath).
|
||||
WithOperator(operator)
|
||||
|
||||
// Special handling for IsNull and IsNotNull
|
||||
switch q.Operator {
|
||||
case QueryOperatorIsNull:
|
||||
whereClause = whereClause.WithValueBoolean(true)
|
||||
case QueryOperatorIsNotNull:
|
||||
whereClause = whereClause.WithValueBoolean(false)
|
||||
default:
|
||||
// Set value based on type
|
||||
switch v := q.Value.(type) {
|
||||
case string:
|
||||
whereClause = whereClause.WithValueString(v)
|
||||
case int:
|
||||
whereClause = whereClause.WithValueInt(int64(v))
|
||||
case int64:
|
||||
whereClause = whereClause.WithValueInt(v)
|
||||
case float32:
|
||||
whereClause = whereClause.WithValueNumber(float64(v))
|
||||
case float64:
|
||||
whereClause = whereClause.WithValueNumber(v)
|
||||
case bool:
|
||||
whereClause = whereClause.WithValueBoolean(v)
|
||||
default:
|
||||
// Fallback to string conversion
|
||||
whereClause = whereClause.WithValueString(fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
|
||||
operands = append(operands, whereClause)
|
||||
}
|
||||
|
||||
if len(operands) == 1 {
|
||||
return operands[0]
|
||||
}
|
||||
|
||||
// Create AND filter for multiple operands
|
||||
return filters.Where().
|
||||
WithOperator(filters.And).
|
||||
WithOperands(operands)
|
||||
}
|
||||
|
||||
// convertOperator converts string operator to filters operator
|
||||
func convertOperator(op QueryOperator) filters.WhereOperator {
|
||||
switch op {
|
||||
case QueryOperatorEqual:
|
||||
return filters.Equal
|
||||
case QueryOperatorNotEqual:
|
||||
return filters.NotEqual
|
||||
case QueryOperatorLessThan:
|
||||
return filters.LessThan
|
||||
case QueryOperatorLessThanOrEqual:
|
||||
return filters.LessThanEqual
|
||||
case QueryOperatorGreaterThan:
|
||||
return filters.GreaterThan
|
||||
case QueryOperatorGreaterThanOrEqual:
|
||||
return filters.GreaterThanEqual
|
||||
case QueryOperatorLike:
|
||||
return filters.Like
|
||||
case QueryOperatorContainsAny:
|
||||
return filters.ContainsAny
|
||||
case QueryOperatorContainsAll:
|
||||
return filters.ContainsAll
|
||||
case QueryOperatorIsNull:
|
||||
return filters.IsNull
|
||||
case QueryOperatorIsNotNull: // IsNotNull is not supported by Weaviate, so we use IsNull and negate it.
|
||||
return filters.IsNull
|
||||
default:
|
||||
// Default to Equal if unknown
|
||||
return filters.Equal
|
||||
}
|
||||
}
|
||||
812
framework/vectorstore/weaviate_test.go
Normal file
812
framework/vectorstore/weaviate_test.go
Normal file
@@ -0,0 +1,812 @@
|
||||
package vectorstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
|
||||
"github.com/weaviate/weaviate/entities/models"
|
||||
)
|
||||
|
||||
// Shared settings for the Weaviate store tests.
const (
	// Class and vector size used by the tests.
	TestClassName    = "TestWeaviate"
	TestEmbeddingDim = 384

	// Connection defaults, used when the WEAVIATE_* environment variables are unset.
	DefaultTestScheme = "http"
	DefaultTestHost   = "localhost:9000"

	// TestTimeout bounds the context shared by a TestSetup; DefaultTestTimeout
	// is the fallback store timeout when WEAVIATE_TIMEOUT cannot be parsed.
	TestTimeout        = 30 * time.Second
	DefaultTestTimeout = 10 * time.Second
)
|
||||
|
||||
// TestSetup provides common test infrastructure
|
||||
type TestSetup struct {
|
||||
Store *WeaviateStore
|
||||
Logger schemas.Logger
|
||||
Config WeaviateConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewTestSetup creates a test setup with environment-driven configuration
|
||||
func NewTestSetup(t *testing.T) *TestSetup {
|
||||
// Get configuration from environment variables
|
||||
scheme := getEnvWithDefault("WEAVIATE_SCHEME", DefaultTestScheme)
|
||||
host := schemas.NewEnvVar(getEnvWithDefault("WEAVIATE_HOST", DefaultTestHost))
|
||||
|
||||
timeoutStr := getEnvWithDefault("WEAVIATE_TIMEOUT", "10s")
|
||||
timeout, err := time.ParseDuration(timeoutStr)
|
||||
if err != nil {
|
||||
timeout = DefaultTestTimeout
|
||||
}
|
||||
|
||||
config := WeaviateConfig{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
APIKey: schemas.NewEnvVar("env.WEAVIATE_API_KEY"),
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), TestTimeout)
|
||||
|
||||
store, err := newWeaviateStore(ctx, &config, logger)
|
||||
if err != nil {
|
||||
cancel()
|
||||
t.Fatalf("Failed to create Weaviate store: %v", err)
|
||||
}
|
||||
|
||||
setup := &TestSetup{
|
||||
Store: store,
|
||||
Logger: logger,
|
||||
Config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Ensure class exists for integration tests
|
||||
if !testing.Short() {
|
||||
setup.ensureClassExists(t)
|
||||
}
|
||||
|
||||
return setup
|
||||
}
|
||||
|
||||
// Cleanup cleans up test resources
|
||||
func (ts *TestSetup) Cleanup(t *testing.T) {
|
||||
defer ts.cancel()
|
||||
|
||||
if !testing.Short() {
|
||||
// Clean up test data
|
||||
ts.cleanupTestData(t)
|
||||
}
|
||||
|
||||
if err := ts.Store.Close(ts.ctx, TestClassName); err != nil {
|
||||
t.Logf("Warning: Failed to close store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureClassExists creates the test class in Weaviate
|
||||
func (ts *TestSetup) ensureClassExists(t *testing.T) {
|
||||
// Try to get class schema first
|
||||
exists, err := ts.Store.client.Schema().ClassGetter().
|
||||
WithClassName(TestClassName).
|
||||
Do(ts.ctx)
|
||||
|
||||
if err == nil && exists != nil {
|
||||
t.Logf("Class %s already exists", TestClassName)
|
||||
return
|
||||
}
|
||||
|
||||
// Create class with minimal schema - let Weaviate auto-create properties
|
||||
class := &models.Class{
|
||||
Class: TestClassName,
|
||||
Properties: []*models.Property{
|
||||
{
|
||||
Name: "key",
|
||||
DataType: []string{"text"},
|
||||
},
|
||||
{
|
||||
Name: "test_type",
|
||||
DataType: []string{"text"},
|
||||
},
|
||||
{
|
||||
Name: "size",
|
||||
DataType: []string{"int"},
|
||||
},
|
||||
{
|
||||
Name: "public",
|
||||
DataType: []string{"boolean"},
|
||||
},
|
||||
},
|
||||
VectorIndexConfig: map[string]interface{}{
|
||||
"distance": "cosine",
|
||||
},
|
||||
}
|
||||
|
||||
err = ts.Store.client.Schema().ClassCreator().
|
||||
WithClass(class).
|
||||
Do(ts.ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to create test class %s: %v", TestClassName, err)
|
||||
t.Logf("This might be due to auto-schema creation. Continuing...")
|
||||
} else {
|
||||
t.Logf("Created test class: %s", TestClassName)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupTestData removes all test objects from the class
|
||||
func (ts *TestSetup) cleanupTestData(t *testing.T) {
|
||||
// Delete all objects in the test class
|
||||
allTestKeys, _, err := ts.Store.GetAll(ts.ctx, TestClassName, []Query{}, []string{}, nil, 1000)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to get all test keys: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range allTestKeys {
|
||||
err := ts.Store.Delete(ts.ctx, TestClassName, key.ID)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete test key %s: %v", key.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Cleaned up test class: %s", TestClassName)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// UNIT TESTS
|
||||
// ============================================================================
|
||||
|
||||
func TestWeaviateConfig_Validation(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config WeaviateConfig
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: WeaviateConfig{
|
||||
Scheme: "http",
|
||||
Host: schemas.NewEnvVar("localhost:8080"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing scheme",
|
||||
config: WeaviateConfig{
|
||||
Host: schemas.NewEnvVar("localhost:8080"),
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "scheme and host are required",
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
config: WeaviateConfig{
|
||||
Scheme: "http",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "scheme and host are required",
|
||||
},
|
||||
{
|
||||
name: "with api key",
|
||||
config: WeaviateConfig{
|
||||
Scheme: "https",
|
||||
Host: schemas.NewEnvVar("cluster.weaviate.network"),
|
||||
APIKey: schemas.NewEnvVar("test-key"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "with custom headers",
|
||||
config: WeaviateConfig{
|
||||
Scheme: "http",
|
||||
Host: schemas.NewEnvVar("localhost:8080"),
|
||||
Headers: map[string]string{
|
||||
"Custom-Header": "value",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store, err := newWeaviateStore(ctx, &tt.config, logger)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, store)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
// Note: This will fail with connection error in unit tests
|
||||
// but should pass config validation
|
||||
assert.Nil(t, store) // Expected due to no real Weaviate instance
|
||||
assert.Error(t, err) // Connection error expected
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultClassName(t *testing.T) {
|
||||
config := WeaviateConfig{
|
||||
Scheme: "http",
|
||||
Host: schemas.NewEnvVar("localhost:8080"),
|
||||
}
|
||||
|
||||
// This will fail to connect but should set default class name
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
_, err := newWeaviateStore(context.Background(), &config, logger)
|
||||
|
||||
// Should fail with connection error, but we can't test the default class name
|
||||
// without mocking the client, which would be more complex
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBuildWeaviateFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queries []Query
|
||||
expected *filters.WhereBuilder // We'll test the structure, not exact equality
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "empty queries",
|
||||
queries: []Query{},
|
||||
expected: nil,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "single string query",
|
||||
queries: []Query{
|
||||
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
|
||||
},
|
||||
isNil: false,
|
||||
},
|
||||
{
|
||||
name: "single numeric query",
|
||||
queries: []Query{
|
||||
{Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000},
|
||||
},
|
||||
isNil: false,
|
||||
},
|
||||
{
|
||||
name: "multiple queries (AND)",
|
||||
queries: []Query{
|
||||
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
|
||||
{Field: "public", Operator: QueryOperatorEqual, Value: true},
|
||||
},
|
||||
isNil: false,
|
||||
},
|
||||
{
|
||||
name: "mixed types",
|
||||
queries: []Query{
|
||||
{Field: "name", Operator: QueryOperatorLike, Value: "test%"},
|
||||
{Field: "count", Operator: QueryOperatorLessThan, Value: int64(100)},
|
||||
{Field: "active", Operator: QueryOperatorEqual, Value: true},
|
||||
{Field: "score", Operator: QueryOperatorGreaterThanOrEqual, Value: 95.5},
|
||||
},
|
||||
isNil: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildWeaviateFilter(tt.queries)
|
||||
|
||||
if tt.isNil {
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NotNil(t, result)
|
||||
// We can't easily test the internal structure without reflection
|
||||
// or implementing String() methods, but we verify it's not nil
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOperator(t *testing.T) {
|
||||
tests := []struct {
|
||||
input QueryOperator
|
||||
expected filters.WhereOperator
|
||||
}{
|
||||
{QueryOperatorEqual, filters.Equal},
|
||||
{QueryOperatorNotEqual, filters.NotEqual},
|
||||
{QueryOperatorLessThan, filters.LessThan},
|
||||
{QueryOperatorLessThanOrEqual, filters.LessThanEqual},
|
||||
{QueryOperatorGreaterThan, filters.GreaterThan},
|
||||
{QueryOperatorGreaterThanOrEqual, filters.GreaterThanEqual},
|
||||
{QueryOperatorLike, filters.Like},
|
||||
{QueryOperatorContainsAny, filters.ContainsAny},
|
||||
{QueryOperatorContainsAll, filters.ContainsAll},
|
||||
{QueryOperatorIsNull, filters.IsNull},
|
||||
{QueryOperatorIsNotNull, filters.IsNull},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.input), func(t *testing.T) {
|
||||
result := convertOperator(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// INTEGRATION TESTS (require real Weaviate instance)
|
||||
// ============================================================================
|
||||
|
||||
func TestWeaviateStore_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
t.Run("Add and GetChunk", func(t *testing.T) {
|
||||
testKey := generateUUID()
|
||||
embedding := generateTestEmbedding(TestEmbeddingDim)
|
||||
metadata := map[string]interface{}{
|
||||
"type": "document",
|
||||
"size": 1024,
|
||||
"public": true,
|
||||
}
|
||||
|
||||
// Add object
|
||||
err := setup.Store.Add(setup.ctx, TestClassName, testKey, embedding, metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Small delay to ensure consistency
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Get single chunk
|
||||
result, err := setup.Store.GetChunk(setup.ctx, TestClassName, testKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Equal(t, "document", result.Properties["type"]) // Should contain metadata
|
||||
})
|
||||
|
||||
t.Run("Add without embedding", func(t *testing.T) {
|
||||
testKey := generateUUID()
|
||||
metadata := map[string]interface{}{
|
||||
"type": "metadata-only",
|
||||
}
|
||||
|
||||
// Add object without embedding
|
||||
err := setup.Store.Add(setup.ctx, TestClassName, testKey, nil, metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Retrieve it
|
||||
result, err := setup.Store.GetChunk(setup.ctx, TestClassName, testKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "metadata-only", result.Properties["type"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeaviateStore_FilteringScenarios(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
// Setup test data for filtering scenarios
|
||||
testData := []struct {
|
||||
key string
|
||||
metadata map[string]interface{}
|
||||
}{
|
||||
{
|
||||
generateUUID(),
|
||||
map[string]interface{}{
|
||||
"type": "pdf",
|
||||
"size": 1024,
|
||||
"public": true,
|
||||
"author": "alice",
|
||||
},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
map[string]interface{}{
|
||||
"type": "docx",
|
||||
"size": 2048,
|
||||
"public": false,
|
||||
"author": "bob",
|
||||
},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
map[string]interface{}{
|
||||
"type": "pdf",
|
||||
"size": 512,
|
||||
"public": true,
|
||||
"author": "alice",
|
||||
},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
map[string]interface{}{
|
||||
"type": "txt",
|
||||
"size": 256,
|
||||
"public": true,
|
||||
"author": "charlie",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filterFields := []string{"type", "size", "public", "author"}
|
||||
|
||||
// Add all test data
|
||||
for _, item := range testData {
|
||||
embedding := generateTestEmbedding(TestEmbeddingDim)
|
||||
err := setup.Store.Add(setup.ctx, TestClassName, item.key, embedding, item.metadata)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond) // Wait for consistency
|
||||
|
||||
t.Run("Filter by numeric comparison", func(t *testing.T) {
|
||||
queries := []Query{
|
||||
{Field: "size", Operator: "GreaterThan", Value: 1000},
|
||||
}
|
||||
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // doc1 (1024) and doc2 (2048)
|
||||
})
|
||||
|
||||
t.Run("Filter by boolean", func(t *testing.T) {
|
||||
queries := []Query{
|
||||
{Field: "public", Operator: "Equal", Value: true},
|
||||
}
|
||||
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 3) // doc1, doc3, doc4
|
||||
})
|
||||
|
||||
t.Run("Multiple filters (AND)", func(t *testing.T) {
|
||||
queries := []Query{
|
||||
{Field: "type", Operator: "Equal", Value: "pdf"},
|
||||
{Field: "public", Operator: "Equal", Value: true},
|
||||
}
|
||||
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // doc1 and doc3
|
||||
})
|
||||
|
||||
t.Run("Complex multi-condition filter", func(t *testing.T) {
|
||||
queries := []Query{
|
||||
{Field: "author", Operator: "Equal", Value: "alice"},
|
||||
{Field: "size", Operator: "LessThan", Value: 2000},
|
||||
{Field: "public", Operator: "Equal", Value: true},
|
||||
}
|
||||
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // doc1 and doc3 (both by alice, < 2000 size, public)
|
||||
})
|
||||
|
||||
t.Run("Pagination test", func(t *testing.T) {
|
||||
// Test with limit of 2
|
||||
results, cursor, err := setup.Store.GetAll(setup.ctx, TestClassName, nil, filterFields, nil, 2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
if cursor != nil {
|
||||
// Get next page
|
||||
nextResults, _, err := setup.Store.GetAll(setup.ctx, TestClassName, nil, filterFields, cursor, 2)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(nextResults), 2)
|
||||
t.Logf("First page: %d results, Next page: %d results", len(results), len(nextResults))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeaviateStore_CompleteUseCases(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
t.Run("Document Storage & Retrieval Scenario", func(t *testing.T) {
|
||||
// Add documents with different types
|
||||
documents := []struct {
|
||||
key string
|
||||
embedding []float32
|
||||
metadata map[string]interface{}
|
||||
}{
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{"type": "pdf", "size": 1024, "public": true},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{"type": "docx", "size": 2048, "public": false},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{"type": "pdf", "size": 512, "public": true},
|
||||
},
|
||||
}
|
||||
|
||||
filterFields := []string{"type", "size", "public", "author"}
|
||||
|
||||
for _, doc := range documents {
|
||||
err := setup.Store.Add(setup.ctx, TestClassName, doc.key, doc.embedding, doc.metadata)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Test various retrieval patterns
|
||||
|
||||
// Get PDF documents
|
||||
pdfQuery := []Query{{Field: "type", Operator: "Equal", Value: "pdf"}}
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, pdfQuery, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // doc1, doc3
|
||||
|
||||
// Get large documents (size > 1000)
|
||||
sizeQuery := []Query{{Field: "size", Operator: "GreaterThan", Value: 1000}}
|
||||
results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, sizeQuery, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // doc1, doc2
|
||||
|
||||
// Get public PDFs
|
||||
combinedQuery := []Query{
|
||||
{Field: "public", Operator: "Equal", Value: true},
|
||||
{Field: "type", Operator: "Equal", Value: "pdf"},
|
||||
}
|
||||
results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, combinedQuery, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // doc1, doc3
|
||||
|
||||
// Vector similarity search
|
||||
queryEmbedding := documents[0].embedding // Similar to doc1
|
||||
vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, queryEmbedding, nil, filterFields, 0.8, 10)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(vectorResults), 1)
|
||||
})
|
||||
|
||||
t.Run("User Content Management Scenario", func(t *testing.T) {
|
||||
// Add user content with metadata
|
||||
userContent := []struct {
|
||||
key string
|
||||
embedding []float32
|
||||
metadata map[string]interface{}
|
||||
}{
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{"user": "alice", "lang": "en", "category": "tech"},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{"user": "bob", "lang": "es", "category": "tech"},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{"user": "alice", "lang": "en", "category": "sports"},
|
||||
},
|
||||
}
|
||||
|
||||
filterFields := []string{"user", "lang", "category"}
|
||||
|
||||
for _, content := range userContent {
|
||||
err := setup.Store.Add(setup.ctx, TestClassName, content.key, content.embedding, content.metadata)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Test user-specific filtering
|
||||
aliceQuery := []Query{{Field: "user", Operator: "Equal", Value: "alice"}}
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, aliceQuery, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Alice's content
|
||||
|
||||
// English tech content
|
||||
techEnQuery := []Query{
|
||||
{Field: "lang", Operator: "Equal", Value: "en"},
|
||||
{Field: "category", Operator: "Equal", Value: "tech"},
|
||||
}
|
||||
results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, techEnQuery, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1) // user1_content
|
||||
|
||||
// Alice's similar content (semantic search with user filter)
|
||||
aliceFilter := []Query{{Field: "user", Operator: "Equal", Value: "alice"}}
|
||||
queryEmbedding := userContent[0].embedding
|
||||
vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, queryEmbedding, aliceFilter, filterFields, 0.1, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, vectorResults, 2) // Both of Alice's content
|
||||
})
|
||||
|
||||
t.Run("Semantic Cache-like Workflow", func(t *testing.T) {
|
||||
// Add request-response pairs with parameters
|
||||
cacheEntries := []struct {
|
||||
key string
|
||||
embedding []float32
|
||||
metadata map[string]interface{}
|
||||
}{
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{
|
||||
"request_hash": "abc123",
|
||||
"user": "u1",
|
||||
"lang": "en",
|
||||
"response": "answer1",
|
||||
},
|
||||
},
|
||||
{
|
||||
generateUUID(),
|
||||
generateTestEmbedding(TestEmbeddingDim),
|
||||
map[string]interface{}{
|
||||
"request_hash": "def456",
|
||||
"user": "u1",
|
||||
"lang": "es",
|
||||
"response": "answer2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filterFields := []string{"request_hash", "user", "lang", "response"}
|
||||
|
||||
for _, entry := range cacheEntries {
|
||||
err := setup.Store.Add(setup.ctx, TestClassName, entry.key, entry.embedding, entry.metadata)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Test hash-based direct retrieval (exact match)
|
||||
hashQuery := []Query{{Field: "request_hash", Operator: "Equal", Value: "abc123"}}
|
||||
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, hashQuery, filterFields, nil, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
|
||||
// Test semantic search with user and language filters
|
||||
userLangFilter := []Query{
|
||||
{Field: "user", Operator: "Equal", Value: "u1"},
|
||||
{Field: "lang", Operator: "Equal", Value: "en"},
|
||||
}
|
||||
similarEmbedding := generateSimilarEmbedding(cacheEntries[0].embedding, 0.9)
|
||||
vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, similarEmbedding, userLangFilter, filterFields, 0.7, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, vectorResults, 1) // Should find English content for u1
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// INTERFACE COMPLIANCE TESTS
|
||||
// ============================================================================
|
||||
|
||||
func TestWeaviateStore_InterfaceCompliance(t *testing.T) {
|
||||
// Verify that WeaviateStore implements VectorStore interface
|
||||
var _ VectorStore = (*WeaviateStore)(nil)
|
||||
}
|
||||
|
||||
func TestVectorStoreFactory_Weaviate(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
config := &Config{
|
||||
Enabled: true,
|
||||
Type: VectorStoreTypeWeaviate,
|
||||
Config: WeaviateConfig{
|
||||
Scheme: getEnvWithDefault("WEAVIATE_SCHEME", DefaultTestScheme),
|
||||
Host: schemas.NewEnvVar(getEnvWithDefault("WEAVIATE_HOST", DefaultTestHost)),
|
||||
APIKey: schemas.NewEnvVar("env.WEAVIATE_API_KEY"),
|
||||
},
|
||||
}
|
||||
|
||||
store, err := NewVectorStore(context.Background(), config, logger)
|
||||
if err != nil {
|
||||
t.Skipf("Could not create Weaviate store: %v", err)
|
||||
}
|
||||
defer store.Close(context.Background(), TestClassName)
|
||||
|
||||
// Verify it's actually a WeaviateStore
|
||||
weaviateStore, ok := store.(*WeaviateStore)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, weaviateStore)
|
||||
}
|
||||
|
||||
func TestWeaviateStore_NamespaceDimensionHandling(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup(t)
|
||||
|
||||
testClassName := "TestDimensionHandling"
|
||||
|
||||
t.Run("Recreate class with different dimension should not crash", func(t *testing.T) {
|
||||
properties := map[string]VectorStoreProperties{
|
||||
"type": {DataType: VectorStorePropertyTypeString},
|
||||
"test": {DataType: VectorStorePropertyTypeString},
|
||||
}
|
||||
|
||||
// Step 1: Create class with dimension 512
|
||||
err := setup.Store.CreateNamespace(setup.ctx, testClassName, 512, properties)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a document with 512-dimensional embedding
|
||||
testKey512 := generateUUID()
|
||||
embedding512 := generateTestEmbedding(512)
|
||||
metadata := map[string]interface{}{
|
||||
"type": "test_doc",
|
||||
"test": "dimension_512",
|
||||
}
|
||||
|
||||
err = setup.Store.Add(setup.ctx, testClassName, testKey512, embedding512, metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it was added
|
||||
result, err := setup.Store.GetChunk(setup.ctx, testClassName, testKey512)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "dimension_512", result.Properties["test"])
|
||||
|
||||
// Step 2: Delete the class/namespace
|
||||
err = setup.Store.DeleteNamespace(setup.ctx, testClassName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 3: Create class with same name but different dimension - should not crash
|
||||
err = setup.Store.CreateNamespace(setup.ctx, testClassName, 1024, properties)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a document with 1024-dimensional embedding
|
||||
testKey1024 := generateUUID()
|
||||
embedding1024 := generateTestEmbedding(1024)
|
||||
metadata1024 := map[string]interface{}{
|
||||
"type": "test_doc",
|
||||
"test": "dimension_1024",
|
||||
}
|
||||
|
||||
err = setup.Store.Add(setup.ctx, testClassName, testKey1024, embedding1024, metadata1024)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify new document exists
|
||||
result, err = setup.Store.GetChunk(setup.ctx, testClassName, testKey1024)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "dimension_1024", result.Properties["test"])
|
||||
|
||||
// Verify vector search works with new dimension
|
||||
vectorResults, err := setup.Store.GetNearest(setup.ctx, testClassName, embedding1024, nil, []string{"type", "test"}, 0.8, 10)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(vectorResults), 1)
|
||||
assert.NotNil(t, vectorResults[0].Score)
|
||||
|
||||
// Cleanup
|
||||
err = setup.Store.DeleteNamespace(setup.ctx, testClassName)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to cleanup class: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user