first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
package vectorstore
import "errors"
var (
ErrNotFound = errors.New("vectorstore: not found")
ErrNotSupported = errors.New("vectorstore: operation not supported on this store")
ErrQuerySyntax = errors.New("vectorstore: query syntax error")
)

View File

@@ -0,0 +1,649 @@
package vectorstore
import (
"context"
"fmt"
"strings"
"sync"
"github.com/maximhq/bifrost/core/schemas"
"github.com/pinecone-io/go-pinecone/v5/pinecone"
"google.golang.org/protobuf/types/known/structpb"
)
// PineconeConfig represents the configuration for the Pinecone vector store.
type PineconeConfig struct {
APIKey schemas.EnvVar `json:"api_key"` // Pinecone API key - REQUIRED
IndexHost schemas.EnvVar `json:"index_host"` // Index host URL from Pinecone console - REQUIRED
}
// PineconeStore represents the Pinecone vector store.
type PineconeStore struct {
client *pinecone.Client
indexConn *pinecone.IndexConnection
config *PineconeConfig
logger schemas.Logger
mu sync.RWMutex // Protects namespaces and dimension
namespaces map[string]*pinecone.IndexConnection
dimension int // Store dimension for zero vector queries in GetAll
}
// Ping checks if the Pinecone server is reachable.
func (s *PineconeStore) Ping(ctx context.Context) error {
_, err := s.indexConn.DescribeIndexStats(ctx)
return err
}
// CreateNamespace creates a new namespace in the Pinecone vector store.
// Note: Pinecone namespaces are created implicitly when upserting vectors.
// This method is a no-op but ensures the connection is valid.
func (s *PineconeStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error {
// Store dimension for use in GetAll (zero vector queries)
s.mu.Lock()
s.dimension = dimension
s.mu.Unlock()
// Pinecone namespaces are created automatically on first upsert.
// We just verify the index connection is valid.
_, err := s.indexConn.DescribeIndexStats(ctx)
if err != nil {
return fmt.Errorf("failed to verify index connection: %w", err)
}
return nil
}
// DeleteNamespace deletes a namespace from the Pinecone vector store.
func (s *PineconeStore) DeleteNamespace(ctx context.Context, namespace string) error {
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return err
}
return idxConn.DeleteAllVectorsInNamespace(ctx)
}
// GetChunk retrieves a single vector from the Pinecone vector store.
func (s *PineconeStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) {
if strings.TrimSpace(id) == "" {
return SearchResult{}, fmt.Errorf("id is required")
}
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return SearchResult{}, err
}
res, err := idxConn.FetchVectors(ctx, []string{id})
if err != nil {
return SearchResult{}, fmt.Errorf("failed to fetch vector: %w", err)
}
if len(res.Vectors) == 0 {
return SearchResult{}, fmt.Errorf("not found: %s", id)
}
vec, exists := res.Vectors[id]
if !exists || vec == nil {
return SearchResult{}, fmt.Errorf("not found: %s", id)
}
return SearchResult{
ID: id,
Properties: metadataToMap(vec.Metadata),
}, nil
}
// GetChunks retrieves multiple vectors from the Pinecone vector store.
func (s *PineconeStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) {
if len(ids) == 0 {
return []SearchResult{}, nil
}
// Filter out empty IDs
validIDs := make([]string, 0, len(ids))
for _, id := range ids {
if strings.TrimSpace(id) != "" {
validIDs = append(validIDs, id)
}
}
if len(validIDs) == 0 {
return []SearchResult{}, nil
}
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return nil, err
}
res, err := idxConn.FetchVectors(ctx, validIDs)
if err != nil {
return nil, fmt.Errorf("failed to fetch vectors: %w", err)
}
results := make([]SearchResult, 0, len(res.Vectors))
for id, vec := range res.Vectors {
if vec != nil {
results = append(results, SearchResult{
ID: id,
Properties: metadataToMap(vec.Metadata),
})
}
}
return results, nil
}
// GetAll retrieves all vectors with optional filtering and pagination.
// Note: This implementation uses QueryByVectorValues with a zero vector instead of ListVectors
// because ListVectors has severe eventual consistency issues on Pinecone Serverless/Starter indexes.
// The metadata filtering is done server-side by Pinecone, providing much better consistency.
func (s *PineconeStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return nil, nil, err
}
topK := uint32(limit)
if limit <= 0 {
topK = 100
}
// Create zero vector for query - this allows us to use QueryByVectorValues
// which has much better consistency than ListVectors
s.mu.RLock()
dim := s.dimension
s.mu.RUnlock()
if dim <= 0 {
return nil, nil, fmt.Errorf("dimension not set: CreateNamespace must be called before GetAll to set the vector dimension")
}
zeroVector := make([]float32, dim)
queryReq := &pinecone.QueryByVectorValuesRequest{
Vector: zeroVector,
TopK: topK,
IncludeValues: false,
IncludeMetadata: true,
}
// Build metadata filter from queries - filtering is done server-side
if len(queries) > 0 {
filter, filterErr := buildPineconeFilter(queries)
if filterErr != nil {
s.logger.Warn("failed to build pinecone filter, queries may not be applied: %v", filterErr)
}
if filter != nil {
queryReq.MetadataFilter = filter
}
}
res, err := idxConn.QueryByVectorValues(ctx, queryReq)
if err != nil {
return nil, nil, fmt.Errorf("failed to query vectors: %w", err)
}
results := make([]SearchResult, 0, len(res.Matches))
for _, match := range res.Matches {
if match.Vector == nil {
continue
}
props := metadataToMap(match.Vector.Metadata)
filteredProps := filterPropertiesPinecone(props, selectFields)
results = append(results, SearchResult{
ID: match.Vector.Id,
Properties: filteredProps,
})
}
// Note: QueryByVectorValues doesn't support pagination tokens like ListVectors
// For direct hash lookup (the main use case), we only need 1 result anyway
return results, nil, nil
}
// GetNearest retrieves the nearest vectors to a given vector.
func (s *PineconeStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) {
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return nil, err
}
topK := uint32(limit)
if limit <= 0 {
topK = 10
}
queryReq := &pinecone.QueryByVectorValuesRequest{
Vector: vector,
TopK: topK,
IncludeValues: false,
IncludeMetadata: true,
}
// Build metadata filter from queries
if len(queries) > 0 {
filter, err := buildPineconeFilter(queries)
if err != nil {
s.logger.Debug(fmt.Sprintf("failed to build pinecone filter: %v", err))
} else if filter != nil {
queryReq.MetadataFilter = filter
}
}
res, err := idxConn.QueryByVectorValues(ctx, queryReq)
if err != nil {
return nil, fmt.Errorf("failed to query vectors: %w", err)
}
results := make([]SearchResult, 0, len(res.Matches))
for _, match := range res.Matches {
if match.Vector == nil {
continue
}
score := float64(match.Score)
// Apply threshold filter
if score < threshold {
continue
}
props := metadataToMap(match.Vector.Metadata)
filteredProps := filterPropertiesPinecone(props, selectFields)
results = append(results, SearchResult{
ID: match.Vector.Id,
Score: &score,
Properties: filteredProps,
})
}
return results, nil
}
// convertMetadataForStructpb converts metadata map to be compatible with structpb.NewStruct.
// Specifically, it converts []string to []interface{} since structpb doesn't handle []string directly.
func convertMetadataForStructpb(metadata map[string]interface{}) map[string]interface{} {
if metadata == nil {
return nil
}
converted := make(map[string]interface{}, len(metadata))
for k, v := range metadata {
switch val := v.(type) {
case []string:
// Convert []string to []interface{}
interfaceSlice := make([]interface{}, len(val))
for i, s := range val {
interfaceSlice[i] = s
}
converted[k] = interfaceSlice
default:
converted[k] = v
}
}
return converted
}
// Add stores a new vector in the Pinecone vector store.
func (s *PineconeStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error {
if strings.TrimSpace(id) == "" {
return fmt.Errorf("id is required")
}
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return err
}
// Convert metadata to structpb (handle []string -> []interface{} conversion)
var pbMetadata *structpb.Struct
if len(metadata) > 0 {
convertedMetadata := convertMetadataForStructpb(metadata)
pbMetadata, err = structpb.NewStruct(convertedMetadata)
if err != nil {
return fmt.Errorf("failed to convert metadata: %w", err)
}
}
vec := &pinecone.Vector{
Id: id,
Metadata: pbMetadata,
}
if len(embedding) > 0 {
vec.Values = &embedding
}
_, err = idxConn.UpsertVectors(ctx, []*pinecone.Vector{vec})
if err != nil {
return fmt.Errorf("failed to upsert vector: %w", err)
}
return nil
}
// Delete removes a vector from the Pinecone vector store.
func (s *PineconeStore) Delete(ctx context.Context, namespace string, id string) error {
if strings.TrimSpace(id) == "" {
return fmt.Errorf("id is required")
}
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return err
}
return idxConn.DeleteVectorsById(ctx, []string{id})
}
// DeleteAll removes multiple vectors matching the filter.
func (s *PineconeStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) {
idxConn, err := s.getNamespaceConnection(namespace)
if err != nil {
return nil, err
}
// If we have queries, use filter-based deletion
if len(queries) > 0 {
filter, err := buildPineconeFilter(queries)
if err != nil {
return nil, fmt.Errorf("failed to build filter: %w", err)
}
if filter != nil {
err = idxConn.DeleteVectorsByFilter(ctx, filter)
if err != nil {
return nil, fmt.Errorf("failed to delete vectors by filter: %w", err)
}
// Pinecone doesn't return individual results for filter-based deletion
return []DeleteResult{}, nil
}
}
// If no queries, list and delete all vectors in the namespace
listRes, err := idxConn.ListVectors(ctx, &pinecone.ListVectorsRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list vectors: %w", err)
}
if len(listRes.VectorIds) == 0 {
return []DeleteResult{}, nil
}
// Convert []*string to []string
deleteIDs := make([]string, 0, len(listRes.VectorIds))
for _, id := range listRes.VectorIds {
if id != nil {
deleteIDs = append(deleteIDs, *id)
}
}
results := make([]DeleteResult, len(deleteIDs))
for i, id := range deleteIDs {
results[i] = DeleteResult{
ID: id,
Status: DeleteStatusSuccess,
}
}
err = idxConn.DeleteVectorsById(ctx, deleteIDs)
if err != nil {
for i := range results {
results[i].Status = DeleteStatusError
results[i].Error = err.Error()
}
}
return results, nil
}
// Close closes the Pinecone client connection.
// If namespace is non-empty, only that namespace connection is closed.
// If namespace is empty, all connections (indexConn and all namespaces) are closed.
func (s *PineconeStore) Close(ctx context.Context, namespace string) error {
s.mu.Lock()
defer s.mu.Unlock()
// If a specific namespace is provided, close only that connection
if namespace != "" {
if conn, exists := s.namespaces[namespace]; exists && conn != nil {
conn.Close()
delete(s.namespaces, namespace)
}
return nil
}
// Close all connections when namespace is empty
var errs []error
// Close the main index connection
if s.indexConn != nil {
s.indexConn.Close()
s.indexConn = nil
}
// Close and remove all namespace connections
for ns, conn := range s.namespaces {
if conn != nil {
conn.Close()
}
delete(s.namespaces, ns)
}
// Return aggregated errors if any occurred
if len(errs) > 0 {
return fmt.Errorf("errors closing connections: %v", errs)
}
return nil
}
// RequiresVectors returns true because Pinecone is a dedicated vector database
// that requires vectors for all entries with a specific dimension.
func (s *PineconeStore) RequiresVectors() bool {
return true
}
// newPineconeStore creates a new Pinecone vector store.
func newPineconeStore(ctx context.Context, config *PineconeConfig, logger schemas.Logger) (*PineconeStore, error) {
if strings.TrimSpace(config.APIKey.GetValue()) == "" {
return nil, fmt.Errorf("pinecone api_key is required")
}
if strings.TrimSpace(config.IndexHost.GetValue()) == "" {
return nil, fmt.Errorf("pinecone index_host is required")
}
// Creating new client
client, err := pinecone.NewClient(pinecone.NewClientParams{
ApiKey: config.APIKey.GetValue(),
})
if err != nil {
return nil, fmt.Errorf("failed to create pinecone client: %w", err)
}
// Prepare the host URL
// For local connections (Pinecone Local), prefix with http:// to disable TLS
// See: https://docs.pinecone.io/guides/operations/local-development
host := config.IndexHost.GetValue()
if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") {
// Check if this looks like a local connection
if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") {
host = "http://" + host
}
}
// Create index connection
idxConn, err := client.Index(pinecone.NewIndexConnParams{
Host: host,
})
if err != nil {
return nil, fmt.Errorf("failed to create index connection: %w", err)
}
// Verify connection by getting index stats
_, err = idxConn.DescribeIndexStats(ctx)
if err != nil {
return nil, fmt.Errorf("failed to connect to pinecone index: %w", err)
}
return &PineconeStore{
client: client,
indexConn: idxConn,
config: config,
logger: logger,
namespaces: make(map[string]*pinecone.IndexConnection),
}, nil
}
// getHostWithScheme returns the host with the appropriate scheme.
// For local connections (localhost/127.0.0.1), it adds http:// to disable TLS.
func (s *PineconeStore) getHostWithScheme() string {
host := s.config.IndexHost.GetValue()
if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") {
if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") {
return "http://" + host
}
}
return host
}
// getNamespaceConnection returns or creates a connection for the given namespace.
func (s *PineconeStore) getNamespaceConnection(namespace string) (*pinecone.IndexConnection, error) {
if namespace == "" {
return s.indexConn, nil
}
// Check if we already have a connection for this namespace (optimistic read)
s.mu.RLock()
if conn, exists := s.namespaces[namespace]; exists {
s.mu.RUnlock()
return conn, nil
}
s.mu.RUnlock()
// Acquire write lock to create new connection
s.mu.Lock()
defer s.mu.Unlock()
// Double-check after acquiring write lock (another goroutine may have created it)
if conn, exists := s.namespaces[namespace]; exists {
return conn, nil
}
// Create a new connection for this namespace
conn, err := s.client.Index(pinecone.NewIndexConnParams{
Host: s.getHostWithScheme(),
Namespace: namespace,
})
if err != nil {
return nil, fmt.Errorf("failed to create namespace connection: %w", err)
}
s.namespaces[namespace] = conn
return conn, nil
}
// metadataToMap converts protobuf Struct to map[string]interface{}.
func metadataToMap(metadata *structpb.Struct) map[string]interface{} {
if metadata == nil {
return make(map[string]interface{})
}
return metadata.AsMap()
}
// filterPropertiesPinecone filters properties based on selected fields.
func filterPropertiesPinecone(props map[string]interface{}, selectFields []string) map[string]interface{} {
if len(selectFields) == 0 {
return props
}
filtered := make(map[string]interface{}, len(selectFields))
for _, field := range selectFields {
if val, ok := props[field]; ok {
filtered[field] = val
}
}
return filtered
}
// matchesQueries checks if properties match all query conditions.
func matchesQueries(props map[string]interface{}, queries []Query) bool {
if len(queries) == 0 {
return true
}
for _, q := range queries {
val, exists := props[q.Field]
if !matchesQuery(val, exists, q) {
return false
}
}
return true
}
// matchesQuery checks if a single value matches a query condition.
func matchesQuery(val interface{}, exists bool, q Query) bool {
switch q.Operator {
case QueryOperatorIsNull:
return !exists || val == nil
case QueryOperatorIsNotNull:
return exists && val != nil
case QueryOperatorEqual:
return exists && fmt.Sprintf("%v", val) == fmt.Sprintf("%v", q.Value)
case QueryOperatorNotEqual:
return !exists || fmt.Sprintf("%v", val) != fmt.Sprintf("%v", q.Value)
default:
// For complex operators, default to true (filter at query time)
return true
}
}
// buildPineconeFilter converts queries to Pinecone metadata filter.
func buildPineconeFilter(queries []Query) (*structpb.Struct, error) {
if len(queries) == 0 {
return nil, nil
}
filterMap := make(map[string]interface{})
for _, q := range queries {
condition := buildPineconeCondition(q)
if condition != nil {
filterMap[q.Field] = condition
}
}
if len(filterMap) == 0 {
return nil, nil
}
return structpb.NewStruct(filterMap)
}
// buildPineconeCondition builds a single Pinecone filter condition.
func buildPineconeCondition(q Query) interface{} {
switch q.Operator {
case QueryOperatorEqual:
return map[string]interface{}{"$eq": q.Value}
case QueryOperatorNotEqual:
return map[string]interface{}{"$ne": q.Value}
case QueryOperatorGreaterThan:
return map[string]interface{}{"$gt": q.Value}
case QueryOperatorGreaterThanOrEqual:
return map[string]interface{}{"$gte": q.Value}
case QueryOperatorLessThan:
return map[string]interface{}{"$lt": q.Value}
case QueryOperatorLessThanOrEqual:
return map[string]interface{}{"$lte": q.Value}
case QueryOperatorIsNull:
return map[string]interface{}{"$eq": nil}
case QueryOperatorIsNotNull:
return map[string]interface{}{"$ne": nil}
case QueryOperatorContainsAny:
return map[string]interface{}{"$in": q.Value}
case QueryOperatorContainsAll:
// Build an $and array of equality checks so all values must match
values, ok := q.Value.([]interface{})
if !ok {
// Try to convert []string to []interface{}
if strValues, ok := q.Value.([]string); ok {
values = make([]interface{}, len(strValues))
for i, v := range strValues {
values[i] = v
}
} else {
// Fallback to single value equality
return map[string]interface{}{"$eq": q.Value}
}
}
andConditions := make([]interface{}, len(values))
for i, v := range values {
andConditions[i] = map[string]interface{}{"$eq": v}
}
return map[string]interface{}{"$and": andConditions}
default:
return map[string]interface{}{"$eq": q.Value}
}
}

View File

@@ -0,0 +1,611 @@
package vectorstore
import (
"context"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
PineconeTestTimeout = 30 * time.Second
PineconeTestNamespace = "bifrost-test-namespace"
PineconeTestDimension = 1536 // Matches text-embedding-3-small dimension
PineconeTestDefaultAPIKey = "pclocal" // Pinecone Local doesn't validate API keys
PineconeTestDefaultIndexHost = "localhost:5081" // Pinecone Local default port
)
type PineconeTestSetup struct {
Store *PineconeStore
Logger schemas.Logger
Config PineconeConfig
ctx context.Context
cancel context.CancelFunc
}
func NewPineconeTestSetup(t *testing.T) *PineconeTestSetup {
apiKey := schemas.NewEnvVar(getEnvWithDefault("PINECONE_API_KEY", PineconeTestDefaultAPIKey))
indexHost := schemas.NewEnvVar(getEnvWithDefault("PINECONE_INDEX_HOST", PineconeTestDefaultIndexHost))
config := PineconeConfig{
APIKey: *apiKey,
IndexHost: *indexHost,
}
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
ctx, cancel := context.WithTimeout(context.Background(), PineconeTestTimeout)
store, err := newPineconeStore(ctx, &config, logger)
if err != nil {
cancel()
t.Fatalf("Failed to create Pinecone store: %v", err)
}
setup := &PineconeTestSetup{
Store: store,
Logger: logger,
Config: config,
ctx: ctx,
cancel: cancel,
}
return setup
}
func (ts *PineconeTestSetup) Cleanup(t *testing.T) {
defer ts.cancel()
if !testing.Short() {
ts.cleanupTestData(t)
}
if err := ts.Store.Close(ts.ctx, PineconeTestNamespace); err != nil {
t.Logf("Warning: Failed to close store: %v", err)
}
}
func (ts *PineconeTestSetup) cleanupTestData(t *testing.T) {
// Delete all vectors in the test namespace
err := ts.Store.DeleteNamespace(ts.ctx, PineconeTestNamespace)
if err != nil {
t.Logf("Warning: Failed to cleanup test namespace: %v", err)
}
t.Logf("Cleaned up test namespace: %s", PineconeTestNamespace)
}
// ============================================================================
// UNIT TESTS
// ============================================================================
func TestPineconeConfig_Validation(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
ctx := context.Background()
tests := []struct {
name string
config PineconeConfig
expectError bool
errorMsg string
}{
{
name: "missing api key",
config: PineconeConfig{
IndexHost: *schemas.NewEnvVar("https://my-index.svc.environment.pinecone.io"),
},
expectError: true,
errorMsg: "pinecone api_key is required",
},
{
name: "missing index host",
config: PineconeConfig{
APIKey: *schemas.NewEnvVar("test-api-key"),
},
expectError: true,
errorMsg: "pinecone index_host is required",
},
{
name: "empty api key",
config: PineconeConfig{
APIKey: *schemas.NewEnvVar(""),
IndexHost: *schemas.NewEnvVar("https://my-index.svc.environment.pinecone.io"),
},
expectError: true,
errorMsg: "pinecone api_key is required",
},
{
name: "empty index host",
config: PineconeConfig{
APIKey: *schemas.NewEnvVar("test-api-key"),
IndexHost: *schemas.NewEnvVar(""),
},
expectError: true,
errorMsg: "pinecone index_host is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store, err := newPineconeStore(ctx, &tt.config, logger)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, store)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
// Note: This will fail with connection error in unit tests
// but should pass config validation
if err != nil {
assert.Contains(t, err.Error(), "failed to connect")
}
}
})
}
}
func TestBuildPineconeFilter(t *testing.T) {
tests := []struct {
name string
queries []Query
expected bool
}{
{
name: "empty queries",
queries: []Query{},
expected: false,
},
{
name: "single string query",
queries: []Query{
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
},
expected: true,
},
{
name: "single numeric query",
queries: []Query{
{Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000},
},
expected: true,
},
{
name: "multiple queries",
queries: []Query{
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
{Field: "public", Operator: QueryOperatorEqual, Value: true},
},
expected: true,
},
{
name: "not equal query",
queries: []Query{
{Field: "status", Operator: QueryOperatorNotEqual, Value: "deleted"},
},
expected: true,
},
{
name: "range queries",
queries: []Query{
{Field: "count", Operator: QueryOperatorGreaterThanOrEqual, Value: 10},
{Field: "score", Operator: QueryOperatorLessThanOrEqual, Value: 100},
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := buildPineconeFilter(tt.queries)
assert.NoError(t, err)
if tt.expected {
assert.NotNil(t, result)
} else {
assert.Nil(t, result)
}
})
}
}
func TestBuildPineconeCondition(t *testing.T) {
tests := []struct {
name string
query Query
expected map[string]interface{}
}{
{
name: "equal operator",
query: Query{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
expected: map[string]interface{}{"$eq": "tech"},
},
{
name: "not equal operator",
query: Query{Field: "status", Operator: QueryOperatorNotEqual, Value: "deleted"},
expected: map[string]interface{}{"$ne": "deleted"},
},
{
name: "greater than operator",
query: Query{Field: "count", Operator: QueryOperatorGreaterThan, Value: 10},
expected: map[string]interface{}{"$gt": 10},
},
{
name: "greater than or equal operator",
query: Query{Field: "count", Operator: QueryOperatorGreaterThanOrEqual, Value: 10},
expected: map[string]interface{}{"$gte": 10},
},
{
name: "less than operator",
query: Query{Field: "score", Operator: QueryOperatorLessThan, Value: 100},
expected: map[string]interface{}{"$lt": 100},
},
{
name: "less than or equal operator",
query: Query{Field: "score", Operator: QueryOperatorLessThanOrEqual, Value: 100},
expected: map[string]interface{}{"$lte": 100},
},
{
name: "contains any operator",
query: Query{Field: "tags", Operator: QueryOperatorContainsAny, Value: []string{"a", "b"}},
expected: map[string]interface{}{"$in": []string{"a", "b"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildPineconeCondition(tt.query)
assert.Equal(t, tt.expected, result)
})
}
}
func TestMatchesQueries(t *testing.T) {
tests := []struct {
name string
props map[string]interface{}
queries []Query
expected bool
}{
{
name: "empty queries matches all",
props: map[string]interface{}{"type": "document"},
queries: []Query{},
expected: true,
},
{
name: "equal match",
props: map[string]interface{}{"type": "document"},
queries: []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "document"}},
expected: true,
},
{
name: "equal no match",
props: map[string]interface{}{"type": "document"},
queries: []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "image"}},
expected: false,
},
{
name: "not equal match",
props: map[string]interface{}{"type": "document"},
queries: []Query{{Field: "type", Operator: QueryOperatorNotEqual, Value: "image"}},
expected: true,
},
{
name: "is null match",
props: map[string]interface{}{"type": "document"},
queries: []Query{{Field: "author", Operator: QueryOperatorIsNull, Value: nil}},
expected: true,
},
{
name: "is not null match",
props: map[string]interface{}{"type": "document", "author": "alice"},
queries: []Query{{Field: "author", Operator: QueryOperatorIsNotNull, Value: nil}},
expected: true,
},
{
name: "multiple queries all match",
props: map[string]interface{}{"type": "document", "public": true},
queries: []Query{
{Field: "type", Operator: QueryOperatorEqual, Value: "document"},
{Field: "public", Operator: QueryOperatorEqual, Value: true},
},
expected: true,
},
{
name: "multiple queries one fails",
props: map[string]interface{}{"type": "document", "public": false},
queries: []Query{
{Field: "type", Operator: QueryOperatorEqual, Value: "document"},
{Field: "public", Operator: QueryOperatorEqual, Value: true},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchesQueries(tt.props, tt.queries)
assert.Equal(t, tt.expected, result)
})
}
}
func TestFilterPropertiesPinecone(t *testing.T) {
props := map[string]interface{}{
"type": "document",
"author": "alice",
"size": 1024,
"public": true,
}
tests := []struct {
name string
selectFields []string
expected map[string]interface{}
}{
{
name: "empty select returns all",
selectFields: []string{},
expected: props,
},
{
name: "select single field",
selectFields: []string{"type"},
expected: map[string]interface{}{"type": "document"},
},
{
name: "select multiple fields",
selectFields: []string{"type", "author"},
expected: map[string]interface{}{"type": "document", "author": "alice"},
},
{
name: "select non-existent field",
selectFields: []string{"missing"},
expected: map[string]interface{}{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterPropertiesPinecone(props, tt.selectFields)
assert.Equal(t, tt.expected, result)
})
}
}
// ============================================================================
// INTEGRATION TESTS (require real Pinecone instance)
// ============================================================================
func TestPineconeStore_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewPineconeTestSetup(t)
defer setup.Cleanup(t)
// Test Ping
err := setup.Store.Ping(setup.ctx)
require.NoError(t, err)
// Test Add and GetChunk
key := generateUUID()
embedding := generateTestEmbedding(PineconeTestDimension)
metadata := map[string]interface{}{
"type": "document",
"author": "test",
}
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, key, embedding, metadata)
require.NoError(t, err)
// Wait for eventual consistency
time.Sleep(2 * time.Second)
result, err := setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, key)
require.NoError(t, err)
assert.Equal(t, key, result.ID)
assert.Equal(t, "document", result.Properties["type"])
assert.Equal(t, "test", result.Properties["author"])
// Test GetChunks
key2 := generateUUID()
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, key2, generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "image"})
require.NoError(t, err)
time.Sleep(2 * time.Second)
results, err := setup.Store.GetChunks(setup.ctx, PineconeTestNamespace, []string{key, key2})
require.NoError(t, err)
assert.Len(t, results, 2)
}
func TestPineconeStore_VectorSearch(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewPineconeTestSetup(t)
defer setup.Cleanup(t)
// Add test vectors
emb := generateTestEmbedding(PineconeTestDimension)
err := setup.Store.Add(setup.ctx, PineconeTestNamespace, generateUUID(), emb, map[string]interface{}{"type": "tech"})
require.NoError(t, err)
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, generateUUID(), generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "sports"})
require.NoError(t, err)
// Wait for eventual consistency
time.Sleep(3 * time.Second)
// Test vector similarity search
results, err := setup.Store.GetNearest(setup.ctx, PineconeTestNamespace, emb, nil, []string{"type"}, 0.1, 10)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(results), 1)
if len(results) > 0 {
require.NotNil(t, results[0].Score)
}
// Test with filter
queries := []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "tech"}}
results, err = setup.Store.GetNearest(setup.ctx, PineconeTestNamespace, emb, queries, []string{"type"}, 0.1, 10)
require.NoError(t, err)
for _, result := range results {
assert.Equal(t, "tech", result.Properties["type"])
}
}
func TestPineconeStore_Delete(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewPineconeTestSetup(t)
defer setup.Cleanup(t)
// Add a vector
key := generateUUID()
err := setup.Store.Add(setup.ctx, PineconeTestNamespace, key, generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "to-delete"})
require.NoError(t, err)
time.Sleep(2 * time.Second)
// Verify it exists
_, err = setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, key)
require.NoError(t, err)
// Delete it
err = setup.Store.Delete(setup.ctx, PineconeTestNamespace, key)
require.NoError(t, err)
time.Sleep(2 * time.Second)
// Verify it's gone
_, err = setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, key)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestPineconeStore_ErrorHandling(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewPineconeTestSetup(t)
defer setup.Cleanup(t)
// Test GetChunk with non-existent ID
_, err := setup.Store.GetChunk(setup.ctx, PineconeTestNamespace, generateUUID())
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
// Test Add with empty ID
err = setup.Store.Add(setup.ctx, PineconeTestNamespace, "", generateTestEmbedding(PineconeTestDimension), map[string]interface{}{"type": "test"})
assert.Error(t, err)
assert.Contains(t, err.Error(), "id is required")
// Test Delete with empty ID
err = setup.Store.Delete(setup.ctx, PineconeTestNamespace, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "id is required")
}
func TestPineconeStore_SemanticCacheWorkflow(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewPineconeTestSetup(t)
defer setup.Cleanup(t)
// Simulate a semantic cache workflow
cacheEntries := []struct {
key string
embedding []float32
metadata map[string]interface{}
}{
{
generateUUID(),
generateTestEmbedding(PineconeTestDimension),
map[string]interface{}{
"request_hash": "abc123",
"user": "u1",
"lang": "en",
"response": "answer1",
},
},
{
generateUUID(),
generateTestEmbedding(PineconeTestDimension),
map[string]interface{}{
"request_hash": "def456",
"user": "u1",
"lang": "es",
"response": "answer2",
},
},
}
// Add cache entries
for _, entry := range cacheEntries {
err := setup.Store.Add(setup.ctx, PineconeTestNamespace, entry.key, entry.embedding, entry.metadata)
require.NoError(t, err)
}
time.Sleep(3 * time.Second)
// Test semantic search with user filter
userFilter := []Query{{Field: "user", Operator: QueryOperatorEqual, Value: "u1"}}
results, err := setup.Store.GetNearest(setup.ctx, PineconeTestNamespace, cacheEntries[0].embedding, userFilter, []string{"request_hash", "user", "lang", "response"}, 0.1, 10)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(results), 1)
// Verify user filter worked
for _, result := range results {
assert.Equal(t, "u1", result.Properties["user"])
}
}
// ============================================================================
// INTERFACE COMPLIANCE TESTS
// ============================================================================
func TestPineconeStore_InterfaceCompliance(t *testing.T) {
var _ VectorStore = (*PineconeStore)(nil)
}
func TestVectorStoreFactory_Pinecone(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
apiKey := schemas.NewEnvVar(getEnvWithDefault("PINECONE_API_KEY", PineconeTestDefaultAPIKey))
indexHost := schemas.NewEnvVar(getEnvWithDefault("PINECONE_INDEX_HOST", PineconeTestDefaultIndexHost))
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
config := &Config{
Enabled: true,
Type: VectorStoreTypePinecone,
Config: PineconeConfig{
APIKey: *apiKey,
IndexHost: *indexHost,
},
}
store, err := NewVectorStore(context.Background(), config, logger)
if err != nil {
t.Skipf("Could not create Pinecone store: %v", err)
}
defer store.Close(context.Background(), PineconeTestNamespace)
pineconeStore, ok := store.(*PineconeStore)
assert.True(t, ok)
assert.NotNil(t, pineconeStore)
}

View File

@@ -0,0 +1,609 @@
package vectorstore
import (
"context"
"fmt"
"strings"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/qdrant/go-client/qdrant"
)
// QdrantConfig represents the configuration for the Qdrant vector store.
type QdrantConfig struct {
Host schemas.EnvVar `json:"host"` // Qdrant server host - REQUIRED
Port schemas.EnvVar `json:"port"` // Qdrant server port (fallback to 6334 for gRPC)
APIKey schemas.EnvVar `json:"api_key,omitempty"` // API key for authentication - Optional
UseTLS schemas.EnvVar `json:"use_tls,omitempty"` // Use TLS for connection - Optional
}
// QdrantStore represents the Qdrant vector store.
type QdrantStore struct {
client *qdrant.Client
logger schemas.Logger
}
// Ping checks if the Qdrant server is reachable.
func (s *QdrantStore) Ping(ctx context.Context) error {
_, err := s.client.HealthCheck(ctx)
return err
}
// CreateNamespace creates a new collection in the Qdrant vector store.
func (s *QdrantStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error {
exists, err := s.client.CollectionExists(ctx, namespace)
if err != nil {
return fmt.Errorf("failed to check collection existence: %w", err)
}
if !exists {
err = s.client.CreateCollection(ctx, &qdrant.CreateCollection{
CollectionName: namespace,
VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
Size: uint64(dimension),
Distance: qdrant.Distance_Cosine,
}),
})
if err != nil {
return fmt.Errorf("failed to create collection: %w", err)
}
}
for fieldName, prop := range properties {
var fieldType qdrant.FieldType
switch prop.DataType {
case VectorStorePropertyTypeInteger:
fieldType = qdrant.FieldType_FieldTypeInteger
case VectorStorePropertyTypeBoolean:
fieldType = qdrant.FieldType_FieldTypeBool
default:
fieldType = qdrant.FieldType_FieldTypeKeyword
}
_, err = s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{
CollectionName: namespace,
FieldName: fieldName,
FieldType: &fieldType,
})
if err != nil {
s.logger.Debug(fmt.Sprintf("failed to create index for field %s: %v", fieldName, err))
}
}
return nil
}
// DeleteNamespace deletes a collection from the Qdrant vector store.
func (s *QdrantStore) DeleteNamespace(ctx context.Context, namespace string) error {
exists, err := s.client.CollectionExists(ctx, namespace)
if err != nil {
return fmt.Errorf("failed to check collection existence: %w", err)
}
if !exists {
return nil
}
return s.client.DeleteCollection(ctx, namespace)
}
// GetChunk retrieves a single point from the Qdrant vector store.
func (s *QdrantStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) {
if strings.TrimSpace(id) == "" {
return SearchResult{}, fmt.Errorf("id is required")
}
pointID, err := parsePointID(id)
if err != nil {
return SearchResult{}, fmt.Errorf("invalid id format: %w", err)
}
points, err := s.client.Get(ctx, &qdrant.GetPoints{
CollectionName: namespace,
Ids: []*qdrant.PointId{pointID},
WithPayload: qdrant.NewWithPayload(true),
})
if err != nil {
return SearchResult{}, fmt.Errorf("failed to get point: %w", err)
}
if len(points) == 0 {
return SearchResult{}, fmt.Errorf("not found: %s", id)
}
return SearchResult{
ID: id,
Properties: payloadToMap(points[0].Payload),
}, nil
}
// GetChunks retrieves multiple points from the Qdrant vector store.
func (s *QdrantStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) {
if len(ids) == 0 {
return []SearchResult{}, nil
}
pointIDs := make([]*qdrant.PointId, 0, len(ids))
for _, id := range ids {
if strings.TrimSpace(id) == "" {
continue
}
pointID, err := parsePointID(id)
if err != nil {
s.logger.Debug(fmt.Sprintf("skipping invalid id %s: %v", id, err))
continue
}
pointIDs = append(pointIDs, pointID)
}
if len(pointIDs) == 0 {
return []SearchResult{}, nil
}
points, err := s.client.Get(ctx, &qdrant.GetPoints{
CollectionName: namespace,
Ids: pointIDs,
WithPayload: qdrant.NewWithPayload(true),
})
if err != nil {
return nil, fmt.Errorf("failed to get points: %w", err)
}
results := make([]SearchResult, 0, len(points))
for _, point := range points {
results = append(results, SearchResult{
ID: pointIDToString(point.Id),
Properties: payloadToMap(point.Payload),
})
}
return results, nil
}
// GetAll retrieves all points with optional filtering and pagination.
func (s *QdrantStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
filter := buildQdrantFilter(queries)
var offset *qdrant.PointId
if cursor != nil && *cursor != "" {
var err error
offset, err = parsePointID(*cursor)
if err != nil {
s.logger.Debug(fmt.Sprintf("invalid cursor format: %v", err))
}
}
scrollLimit := uint32(limit)
if limit <= 0 {
scrollLimit = 100
}
scrollResult, err := s.client.Scroll(ctx, &qdrant.ScrollPoints{
CollectionName: namespace,
Filter: filter,
Limit: &scrollLimit,
Offset: offset,
WithPayload: qdrant.NewWithPayload(true),
})
if err != nil {
return nil, nil, fmt.Errorf("failed to scroll points: %w", err)
}
results := make([]SearchResult, 0, len(scrollResult))
var lastID string
for _, point := range scrollResult {
lastID = pointIDToString(point.Id)
results = append(results, SearchResult{
ID: lastID,
Properties: filterProperties(payloadToMap(point.Payload), selectFields),
})
}
if len(scrollResult) >= int(scrollLimit) {
return results, &lastID, nil
}
return results, nil, nil
}
// GetNearest retrieves the nearest points to a vector.
func (s *QdrantStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) {
filter := buildQdrantFilter(queries)
searchLimit := uint64(limit)
if limit <= 0 {
searchLimit = 10
}
searchResult, err := s.client.Query(ctx, &qdrant.QueryPoints{
CollectionName: namespace,
Query: qdrant.NewQuery(vector...),
Filter: filter,
Limit: &searchLimit,
WithPayload: qdrant.NewWithPayload(true),
ScoreThreshold: qdrant.PtrOf(float32(threshold)),
})
if err != nil {
return nil, fmt.Errorf("failed to search points: %w", err)
}
results := make([]SearchResult, 0, len(searchResult))
for _, point := range searchResult {
score := float64(point.Score)
results = append(results, SearchResult{
ID: pointIDToString(point.Id),
Score: &score,
Properties: filterProperties(payloadToMap(point.Payload), selectFields),
})
}
return results, nil
}
// Add stores a new point in the Qdrant vector store.
func (s *QdrantStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error {
if strings.TrimSpace(id) == "" {
return fmt.Errorf("id is required")
}
pointID, err := parsePointID(id)
if err != nil {
return fmt.Errorf("invalid id format (must be UUID): %w", err)
}
point := &qdrant.PointStruct{
Id: pointID,
Payload: mapToPayload(metadata),
}
if len(embedding) > 0 {
point.Vectors = qdrant.NewVectors(embedding...)
}
_, err = s.client.Upsert(ctx, &qdrant.UpsertPoints{
CollectionName: namespace,
Points: []*qdrant.PointStruct{point},
Wait: qdrant.PtrOf(true),
})
if err != nil {
return fmt.Errorf("failed to upsert point: %w", err)
}
return nil
}
// Delete removes a point from the Qdrant vector store.
func (s *QdrantStore) Delete(ctx context.Context, namespace string, id string) error {
if strings.TrimSpace(id) == "" {
return fmt.Errorf("id is required")
}
pointID, err := parsePointID(id)
if err != nil {
return fmt.Errorf("invalid id format: %w", err)
}
_, err = s.client.Delete(ctx, &qdrant.DeletePoints{
CollectionName: namespace,
Points: qdrant.NewPointsSelector(pointID),
})
return err
}
// DeleteAll removes multiple points matching the filter.
func (s *QdrantStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) {
filter := buildQdrantFilter(queries)
scrollResult, err := s.client.Scroll(ctx, &qdrant.ScrollPoints{
CollectionName: namespace,
Filter: filter,
WithPayload: qdrant.NewWithPayload(false),
})
if err != nil {
return nil, fmt.Errorf("failed to scroll points: %w", err)
}
if len(scrollResult) == 0 {
return []DeleteResult{}, nil
}
results := make([]DeleteResult, len(scrollResult))
for i, point := range scrollResult {
results[i] = DeleteResult{
ID: pointIDToString(point.Id),
Status: DeleteStatusSuccess,
}
}
var deleteErr error
if filter != nil {
_, deleteErr = s.client.Delete(ctx, &qdrant.DeletePoints{
CollectionName: namespace,
Points: qdrant.NewPointsSelectorFilter(filter),
})
} else {
pointIDs := make([]*qdrant.PointId, len(scrollResult))
for i, point := range scrollResult {
pointIDs[i] = point.Id
}
_, deleteErr = s.client.Delete(ctx, &qdrant.DeletePoints{
CollectionName: namespace,
Points: qdrant.NewPointsSelectorIDs(pointIDs),
})
}
if deleteErr != nil {
for i := range results {
results[i].Status = DeleteStatusError
results[i].Error = deleteErr.Error()
}
}
return results, nil
}
// Close closes the Qdrant client connection.
func (s *QdrantStore) Close(ctx context.Context, namespace string) error {
return s.client.Close()
}
// RequiresVectors returns true because Qdrant is a dedicated vector database
// that requires vectors for all points/entries.
func (s *QdrantStore) RequiresVectors() bool {
return true
}
// newQdrantStore creates a new Qdrant vector store.
func newQdrantStore(ctx context.Context, config *QdrantConfig, logger schemas.Logger) (*QdrantStore, error) {
if strings.TrimSpace(config.Host.GetValue()) == "" {
return nil, fmt.Errorf("qdrant host is required")
}
client, err := qdrant.NewClient(&qdrant.Config{
Host: config.Host.GetValue(),
Port: config.Port.CoerceInt(6334),
APIKey: config.APIKey.GetValue(),
UseTLS: config.UseTLS.CoerceBool(false),
SkipCompatibilityCheck: true,
})
if err != nil {
return nil, fmt.Errorf("failed to create qdrant client: %w", err)
}
_, err = client.HealthCheck(ctx)
if err != nil {
return nil, fmt.Errorf("failed to connect to qdrant: %w", err)
}
return &QdrantStore{
client: client,
logger: logger,
}, nil
}
func parsePointID(id string) (*qdrant.PointId, error) {
if _, err := uuid.Parse(id); err != nil {
return nil, err
}
return qdrant.NewID(id), nil
}
func pointIDToString(id *qdrant.PointId) string {
if id == nil {
return ""
}
switch v := id.PointIdOptions.(type) {
case *qdrant.PointId_Uuid:
return v.Uuid
case *qdrant.PointId_Num:
return fmt.Sprintf("%d", v.Num)
default:
return ""
}
}
func payloadToMap(payload map[string]*qdrant.Value) map[string]interface{} {
if payload == nil {
return make(map[string]interface{})
}
result := make(map[string]interface{}, len(payload))
for k, v := range payload {
result[k] = valueToInterface(v)
}
return result
}
func valueToInterface(v *qdrant.Value) interface{} {
if v == nil {
return nil
}
switch val := v.Kind.(type) {
case *qdrant.Value_StringValue:
return val.StringValue
case *qdrant.Value_IntegerValue:
return val.IntegerValue
case *qdrant.Value_DoubleValue:
return val.DoubleValue
case *qdrant.Value_BoolValue:
return val.BoolValue
case *qdrant.Value_ListValue:
list := make([]interface{}, len(val.ListValue.Values))
for i, item := range val.ListValue.Values {
list[i] = valueToInterface(item)
}
return list
case *qdrant.Value_StructValue:
return payloadToMap(val.StructValue.Fields)
default:
return nil
}
}
func mapToPayload(m map[string]interface{}) map[string]*qdrant.Value {
if m == nil {
return make(map[string]*qdrant.Value)
}
// Convert []string to []interface{} since Qdrant's NewValueMap doesn't handle []string directly
converted := make(map[string]interface{}, len(m))
for k, v := range m {
switch val := v.(type) {
case []string:
// Convert []string to []interface{}
interfaceSlice := make([]interface{}, len(val))
for i, s := range val {
interfaceSlice[i] = s
}
converted[k] = interfaceSlice
default:
converted[k] = v
}
}
return qdrant.NewValueMap(converted)
}
func filterProperties(props map[string]interface{}, selectFields []string) map[string]interface{} {
if len(selectFields) == 0 {
return props
}
filtered := make(map[string]interface{}, len(selectFields))
for _, field := range selectFields {
if val, ok := props[field]; ok {
filtered[field] = val
}
}
return filtered
}
func buildQdrantFilter(queries []Query) *qdrant.Filter {
if len(queries) == 0 {
return nil
}
var conditions []*qdrant.Condition
for _, q := range queries {
condition := buildQdrantCondition(q)
if condition != nil {
conditions = append(conditions, condition)
}
}
if len(conditions) == 0 {
return nil
}
return &qdrant.Filter{
Must: conditions,
}
}
func buildQdrantCondition(q Query) *qdrant.Condition {
field := q.Field
switch q.Operator {
case QueryOperatorEqual:
return buildMatchCondition(field, q.Value)
case QueryOperatorNotEqual:
matchCond := buildMatchCondition(field, q.Value)
if matchCond == nil {
return nil
}
return qdrant.NewFilterAsCondition(&qdrant.Filter{
MustNot: []*qdrant.Condition{matchCond},
})
case QueryOperatorGreaterThan:
return buildRangeCondition(field, q.Value, "gt")
case QueryOperatorGreaterThanOrEqual:
return buildRangeCondition(field, q.Value, "gte")
case QueryOperatorLessThan:
return buildRangeCondition(field, q.Value, "lt")
case QueryOperatorLessThanOrEqual:
return buildRangeCondition(field, q.Value, "lte")
case QueryOperatorIsNull:
return qdrant.NewIsNull(field)
case QueryOperatorIsNotNull:
return qdrant.NewFilterAsCondition(&qdrant.Filter{
MustNot: []*qdrant.Condition{qdrant.NewIsNull(field)},
})
case QueryOperatorContainsAny:
switch v := q.Value.(type) {
case []string:
return qdrant.NewMatchKeywords(field, v...)
case []int:
int64s := make([]int64, len(v))
for i, val := range v {
int64s[i] = int64(val)
}
return qdrant.NewMatchInts(field, int64s...)
case []int64:
return qdrant.NewMatchInts(field, v...)
}
return buildMatchCondition(field, q.Value)
case QueryOperatorContainsAll:
if values, ok := q.Value.([]interface{}); ok {
var mustConditions []*qdrant.Condition
for _, v := range values {
cond := buildMatchCondition(field, v)
if cond != nil {
mustConditions = append(mustConditions, cond)
}
}
if len(mustConditions) > 0 {
return qdrant.NewFilterAsCondition(&qdrant.Filter{
Must: mustConditions,
})
}
}
return buildMatchCondition(field, q.Value)
case QueryOperatorLike:
if str, ok := q.Value.(string); ok {
return qdrant.NewMatchText(field, str)
}
return nil
default:
return buildMatchCondition(field, q.Value)
}
}
func buildMatchCondition(field string, value interface{}) *qdrant.Condition {
switch v := value.(type) {
case string:
return qdrant.NewMatchKeyword(field, v)
case int:
return qdrant.NewMatchInt(field, int64(v))
case int32:
return qdrant.NewMatchInt(field, int64(v))
case int64:
return qdrant.NewMatchInt(field, v)
case bool:
return qdrant.NewMatchBool(field, v)
default:
return qdrant.NewMatchKeyword(field, fmt.Sprintf("%v", v))
}
}
func buildRangeCondition(field string, value interface{}, op string) *qdrant.Condition {
var floatVal float64
switch v := value.(type) {
case int:
floatVal = float64(v)
case int32:
floatVal = float64(v)
case int64:
floatVal = float64(v)
case float32:
floatVal = float64(v)
case float64:
floatVal = v
default:
return nil
}
r := &qdrant.Range{}
switch op {
case "gt":
r.Gt = qdrant.PtrOf(floatVal)
case "gte":
r.Gte = qdrant.PtrOf(floatVal)
case "lt":
r.Lt = qdrant.PtrOf(floatVal)
case "lte":
r.Lte = qdrant.PtrOf(floatVal)
}
return qdrant.NewRange(field, r)
}

View File

@@ -0,0 +1,506 @@
package vectorstore
import (
"context"
"os"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
QdrantTestTimeout = 30 * time.Second
QdrantTestCollection = "bifrost-test-collection"
QdrantTestDefaultHost = "localhost"
QdrantTestDefaultPort = "6334"
QdrantTestDimension = 384
)
type QdrantTestSetup struct {
Store *QdrantStore
Logger schemas.Logger
Config QdrantConfig
ctx context.Context
cancel context.CancelFunc
}
func NewQdrantTestSetup(t *testing.T) *QdrantTestSetup {
host := schemas.NewEnvVar(getEnvWithDefault("QDRANT_HOST", QdrantTestDefaultHost))
port := schemas.NewEnvVar(getEnvWithDefault("QDRANT_PORT", QdrantTestDefaultPort))
apiKey := schemas.NewEnvVar(os.Getenv("QDRANT_API_KEY"))
useTLS := schemas.NewEnvVar(os.Getenv("QDRANT_USE_TLS"))
config := QdrantConfig{
Host: *host,
Port: *port,
APIKey: *apiKey,
UseTLS: *useTLS,
}
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
ctx, cancel := context.WithTimeout(context.Background(), QdrantTestTimeout)
store, err := newQdrantStore(ctx, &config, logger)
if err != nil {
cancel()
t.Fatalf("Failed to create Qdrant store: %v", err)
}
setup := &QdrantTestSetup{
Store: store,
Logger: logger,
Config: config,
ctx: ctx,
cancel: cancel,
}
setup.ensureCollectionExists(t)
return setup
}
func (ts *QdrantTestSetup) Cleanup(t *testing.T) {
defer ts.cancel()
if !testing.Short() {
ts.cleanupTestData(t)
}
if err := ts.Store.Close(ts.ctx, QdrantTestCollection); err != nil {
t.Logf("Warning: Failed to close store: %v", err)
}
}
func (ts *QdrantTestSetup) ensureCollectionExists(t *testing.T) {
properties := map[string]VectorStoreProperties{
"key": {
DataType: VectorStorePropertyTypeString,
},
"type": {
DataType: VectorStorePropertyTypeString,
},
"test_type": {
DataType: VectorStorePropertyTypeString,
},
"size": {
DataType: VectorStorePropertyTypeInteger,
},
"public": {
DataType: VectorStorePropertyTypeBoolean,
},
"author": {
DataType: VectorStorePropertyTypeString,
},
"request_hash": {
DataType: VectorStorePropertyTypeString,
},
"user": {
DataType: VectorStorePropertyTypeString,
},
"lang": {
DataType: VectorStorePropertyTypeString,
},
"category": {
DataType: VectorStorePropertyTypeString,
},
"content": {
DataType: VectorStorePropertyTypeString,
},
"response": {
DataType: VectorStorePropertyTypeString,
},
}
err := ts.Store.CreateNamespace(ts.ctx, QdrantTestCollection, QdrantTestDimension, properties)
if err != nil {
t.Fatalf("Failed to create collection %q: %v", QdrantTestCollection, err)
}
t.Logf("Created test collection: %s", QdrantTestCollection)
}
func (ts *QdrantTestSetup) cleanupTestData(t *testing.T) {
allTestKeys, _, err := ts.Store.GetAll(ts.ctx, QdrantTestCollection, []Query{}, []string{}, nil, 1000)
if err != nil {
t.Logf("Warning: Failed to get all test keys: %v", err)
return
}
for _, key := range allTestKeys {
err := ts.Store.Delete(ts.ctx, QdrantTestCollection, key.ID)
if err != nil {
t.Logf("Warning: Failed to delete test key %s: %v", key.ID, err)
}
}
t.Logf("Cleaned up test collection: %s", QdrantTestCollection)
}
func TestQdrantConfig_Validation(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
ctx := context.Background()
tests := []struct {
name string
config QdrantConfig
expectError bool
errorMsg string
}{
{
name: "valid config",
config: QdrantConfig{
Host: *schemas.NewEnvVar("localhost"),
Port: *schemas.NewEnvVar("6334"),
},
expectError: false,
},
{
name: "missing host",
config: QdrantConfig{
Port: *schemas.NewEnvVar("6334"),
},
expectError: true,
errorMsg: "qdrant host is required",
},
{
name: "missing port uses default",
config: QdrantConfig{
Host: *schemas.NewEnvVar("localhost"),
},
expectError: false, // Port defaults to 6334 via CoerceInt fallback
},
{
name: "with api key",
config: QdrantConfig{
Host: *schemas.NewEnvVar("cluster.qdrant.io"),
Port: *schemas.NewEnvVar("6334"),
APIKey: *schemas.NewEnvVar("test-key"),
},
expectError: false,
},
{
name: "with tls",
config: QdrantConfig{
Host: *schemas.NewEnvVar("localhost"),
Port: *schemas.NewEnvVar("6334"),
UseTLS: *schemas.NewEnvVar("true"),
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store, err := newQdrantStore(ctx, &tt.config, logger)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, store)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
if err != nil {
assert.Contains(t, err.Error(), "failed to connect")
}
}
})
}
}
func TestParsePointID(t *testing.T) {
tests := []struct {
name string
id string
expectError bool
}{
{
name: "valid UUID",
id: "550e8400-e29b-41d4-a716-446655440000",
expectError: false,
},
{
name: "invalid UUID",
id: "not-a-uuid",
expectError: true,
},
{
name: "empty string",
id: "",
expectError: true,
},
{
name: "numeric string",
id: "12345",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pointID, err := parsePointID(tt.id)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, pointID)
} else {
assert.NoError(t, err)
assert.NotNil(t, pointID)
}
})
}
}
func TestBuildQdrantFilter(t *testing.T) {
tests := []struct {
name string
queries []Query
expected bool
}{
{
name: "empty queries",
queries: []Query{},
expected: false,
},
{
name: "single string query",
queries: []Query{
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
},
expected: true,
},
{
name: "single numeric query",
queries: []Query{
{Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000},
},
expected: true,
},
{
name: "multiple queries (AND)",
queries: []Query{
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
{Field: "public", Operator: QueryOperatorEqual, Value: true},
},
expected: true,
},
{
name: "null checks",
queries: []Query{
{Field: "author", Operator: QueryOperatorIsNull, Value: nil},
},
expected: true,
},
{
name: "not null checks",
queries: []Query{
{Field: "author", Operator: QueryOperatorIsNotNull, Value: nil},
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildQdrantFilter(tt.queries)
if tt.expected {
assert.NotNil(t, result)
} else {
assert.Nil(t, result)
}
})
}
}
func TestQdrantStore_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewQdrantTestSetup(t)
defer setup.Cleanup(t)
err := setup.Store.Ping(setup.ctx)
require.NoError(t, err)
key := generateUUID()
err = setup.Store.Add(setup.ctx, QdrantTestCollection, key, generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "document"})
require.NoError(t, err)
result, err := setup.Store.GetChunk(setup.ctx, QdrantTestCollection, key)
require.NoError(t, err)
assert.Equal(t, "document", result.Properties["type"])
keys := []string{generateUUID(), generateUUID()}
for i, k := range keys {
err = setup.Store.Add(setup.ctx, QdrantTestCollection, k, generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": i})
require.NoError(t, err)
}
results, err := setup.Store.GetChunks(setup.ctx, QdrantTestCollection, keys)
require.NoError(t, err)
assert.Len(t, results, 2)
}
func TestQdrantStore_Filtering(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewQdrantTestSetup(t)
defer setup.Cleanup(t)
for i := 0; i < 3; i++ {
metadata := map[string]interface{}{"type": "pdf", "public": true}
if i == 1 {
metadata["type"] = "docx"
metadata["public"] = false
}
err := setup.Store.Add(setup.ctx, QdrantTestCollection, generateUUID(), generateTestEmbedding(QdrantTestDimension), metadata)
require.NoError(t, err)
}
queries := []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "pdf"}}
results, _, err := setup.Store.GetAll(setup.ctx, QdrantTestCollection, queries, []string{"type"}, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2)
multiQuery := []Query{
{Field: "type", Operator: QueryOperatorEqual, Value: "pdf"},
{Field: "public", Operator: QueryOperatorEqual, Value: true},
}
results, _, err = setup.Store.GetAll(setup.ctx, QdrantTestCollection, multiQuery, []string{"type"}, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2)
}
func TestQdrantStore_VectorSearch(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewQdrantTestSetup(t)
defer setup.Cleanup(t)
emb := generateTestEmbedding(QdrantTestDimension)
err := setup.Store.Add(setup.ctx, QdrantTestCollection, generateUUID(), emb, map[string]interface{}{"type": "tech"})
require.NoError(t, err)
err = setup.Store.Add(setup.ctx, QdrantTestCollection, generateUUID(), generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "sports"})
require.NoError(t, err)
results, err := setup.Store.GetNearest(setup.ctx, QdrantTestCollection, emb, nil, []string{"type"}, 0.1, 10)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(results), 1)
require.NotNil(t, results[0].Score)
queries := []Query{{Field: "type", Operator: QueryOperatorEqual, Value: "tech"}}
results, err = setup.Store.GetNearest(setup.ctx, QdrantTestCollection, emb, queries, []string{"type"}, 0.1, 10)
require.NoError(t, err)
for _, result := range results {
assert.Equal(t, "tech", result.Properties["type"])
}
}
func TestQdrantStore_InterfaceCompliance(t *testing.T) {
var _ VectorStore = (*QdrantStore)(nil)
}
func TestVectorStoreFactory_Qdrant(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
host := schemas.NewEnvVar(getEnvWithDefault("QDRANT_HOST", QdrantTestDefaultHost))
port := schemas.NewEnvVar(getEnvWithDefault("QDRANT_PORT", QdrantTestDefaultPort))
apiKey := schemas.NewEnvVar(os.Getenv("QDRANT_API_KEY"))
config := &Config{
Enabled: true,
Type: VectorStoreTypeQdrant,
Config: QdrantConfig{
Host: *host,
Port: *port,
APIKey: *apiKey,
},
}
store, err := NewVectorStore(context.Background(), config, logger)
if err != nil {
t.Skipf("Could not create Qdrant store: %v", err)
}
defer store.Close(context.Background(), QdrantTestCollection)
qdrantStore, ok := store.(*QdrantStore)
assert.True(t, ok)
assert.NotNil(t, qdrantStore)
}
func TestQdrantStore_DimensionHandling(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewQdrantTestSetup(t)
defer setup.Cleanup(t)
testCollection := "TestDim"
props := map[string]VectorStoreProperties{"type": {DataType: VectorStorePropertyTypeString}}
err := setup.Store.CreateNamespace(setup.ctx, testCollection, 512, props)
require.NoError(t, err)
err = setup.Store.Add(setup.ctx, testCollection, generateUUID(), generateTestEmbedding(512), map[string]interface{}{"type": "test"})
require.NoError(t, err)
err = setup.Store.DeleteNamespace(setup.ctx, testCollection)
require.NoError(t, err)
err = setup.Store.CreateNamespace(setup.ctx, testCollection, QdrantTestDimension, props)
require.NoError(t, err)
emb := generateTestEmbedding(QdrantTestDimension)
err = setup.Store.Add(setup.ctx, testCollection, generateUUID(), emb, map[string]interface{}{"type": "test"})
require.NoError(t, err)
results, err := setup.Store.GetNearest(setup.ctx, testCollection, emb, nil, []string{"type"}, 0.8, 10)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(results), 1)
setup.Store.DeleteNamespace(setup.ctx, testCollection)
}
func TestQdrantStore_ErrorHandling(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewQdrantTestSetup(t)
defer setup.Cleanup(t)
_, err := setup.Store.GetChunk(setup.ctx, QdrantTestCollection, generateUUID())
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
err = setup.Store.Add(setup.ctx, QdrantTestCollection, "", generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "test"})
assert.Error(t, err)
assert.Contains(t, err.Error(), "id is required")
err = setup.Store.Add(setup.ctx, QdrantTestCollection, "not-a-uuid", generateTestEmbedding(QdrantTestDimension), map[string]interface{}{"type": "test"})
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid id format")
err = setup.Store.Delete(setup.ctx, QdrantTestCollection, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "id is required")
err = setup.Store.Delete(setup.ctx, QdrantTestCollection, "not-a-uuid")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid id format")
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,238 @@
// Package vectorstore provides a generic interface for vector stores.
package vectorstore
import (
"context"
"encoding/json"
"fmt"
"github.com/maximhq/bifrost/core/schemas"
)
type VectorStoreType string
const (
VectorStoreTypeWeaviate VectorStoreType = "weaviate"
VectorStoreTypeRedis VectorStoreType = "redis"
VectorStoreTypeQdrant VectorStoreType = "qdrant"
VectorStoreTypePinecone VectorStoreType = "pinecone"
)
// Query represents a query to the vector store.
type Query struct {
Field string
Operator QueryOperator
Value interface{}
}
type QueryOperator string
const (
QueryOperatorEqual QueryOperator = "Equal"
QueryOperatorNotEqual QueryOperator = "NotEqual"
QueryOperatorGreaterThan QueryOperator = "GreaterThan"
QueryOperatorLessThan QueryOperator = "LessThan"
QueryOperatorGreaterThanOrEqual QueryOperator = "GreaterThanOrEqual"
QueryOperatorLessThanOrEqual QueryOperator = "LessThanOrEqual"
QueryOperatorLike QueryOperator = "Like"
QueryOperatorContainsAny QueryOperator = "ContainsAny"
QueryOperatorContainsAll QueryOperator = "ContainsAll"
QueryOperatorIsNull QueryOperator = "IsNull"
QueryOperatorIsNotNull QueryOperator = "IsNotNull"
)
// SearchResult represents a search result with metadata.
type SearchResult struct {
ID string
Score *float64
Properties map[string]interface{}
}
// DeleteResult represents the result of a delete operation.
type DeleteResult struct {
ID string
Status DeleteStatus
Error string
}
type DeleteStatus string
const (
DeleteStatusSuccess DeleteStatus = "success"
DeleteStatusError DeleteStatus = "error"
)
type VectorStoreProperties struct {
DataType VectorStorePropertyType `json:"data_type"`
Description string `json:"description"`
}
type VectorStorePropertyType string
const (
VectorStorePropertyTypeString VectorStorePropertyType = "string"
VectorStorePropertyTypeInteger VectorStorePropertyType = "integer"
VectorStorePropertyTypeBoolean VectorStorePropertyType = "boolean"
VectorStorePropertyTypeStringArray VectorStorePropertyType = "string[]"
)
type disableScanFallbackContextKey struct{}
// VectorStore represents the interface for the vector store.
type VectorStore interface {
// Health check
Ping(ctx context.Context) error
// CreateNamespace creates a new namespace in the vector store.
CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error
// DeleteNamespace deletes a namespace from the vector store.
DeleteNamespace(ctx context.Context, namespace string) error
// GetChunk retrieves a single vector from the vector store.
GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error)
// GetChunks retrieves multiple vectors from the vector store.
GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error)
// GetAll retrieves all vectors from the vector store.
GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error)
// GetNearest retrieves the nearest vectors from the vector store.
GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error)
// RequiresVectors returns true if the vector store requires vectors for all entries.
// Dedicated vector databases like Qdrant and Pinecone require vectors, while
// more flexible stores like Weaviate and Redis can store metadata-only entries.
RequiresVectors() bool
// Add stores a new vector in the vector store.
Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error
// Delete removes a vector from the vector store.
Delete(ctx context.Context, namespace string, id string) error
// DeleteAll deletes all vectors from the vector store.
DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error)
// Close closes the vector store.
Close(ctx context.Context, namespace string) error
}
// WithDisableScanFallback returns a derived context that tells vector stores not
// to fall back to full scans when indexed search fails.
func WithDisableScanFallback(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, disableScanFallbackContextKey{}, true)
}
// IsScanFallbackDisabled reports whether scan fallback has been disabled for
// the current vector store operation.
func IsScanFallbackDisabled(ctx context.Context) bool {
if ctx == nil {
return false
}
disabled, _ := ctx.Value(disableScanFallbackContextKey{}).(bool)
return disabled
}
// Config represents the configuration for the vector store.
type Config struct {
Enabled bool `json:"enabled"`
Type VectorStoreType `json:"type"`
Config any `json:"config"`
}
// UnmarshalJSON unmarshals the config from JSON.
func (c *Config) UnmarshalJSON(data []byte) error {
// First, unmarshal into a temporary struct to get the basic fields
type TempConfig struct {
Enabled bool `json:"enabled"`
Type string `json:"type"`
Config json.RawMessage `json:"config"` // Keep as raw JSON
}
var temp TempConfig
if err := json.Unmarshal(data, &temp); err != nil {
return fmt.Errorf("failed to unmarshal config: %w", err)
}
// Set basic fields
c.Enabled = temp.Enabled
c.Type = VectorStoreType(temp.Type)
// Parse the config field based on type
switch c.Type {
case VectorStoreTypeWeaviate:
var weaviateConfig WeaviateConfig
if err := json.Unmarshal(temp.Config, &weaviateConfig); err != nil {
return fmt.Errorf("failed to unmarshal weaviate config: %w", err)
}
c.Config = weaviateConfig
case VectorStoreTypeRedis:
var redisConfig RedisConfig
if err := json.Unmarshal(temp.Config, &redisConfig); err != nil {
return fmt.Errorf("failed to unmarshal redis config: %w", err)
}
// Process env. values for sensitive fields
c.Config = redisConfig
case VectorStoreTypeQdrant:
var qdrantConfig QdrantConfig
if err := json.Unmarshal(temp.Config, &qdrantConfig); err != nil {
return fmt.Errorf("failed to unmarshal qdrant config: %w", err)
}
c.Config = qdrantConfig
case VectorStoreTypePinecone:
var pineconeConfig PineconeConfig
if err := json.Unmarshal(temp.Config, &pineconeConfig); err != nil {
return fmt.Errorf("failed to unmarshal pinecone config: %w", err)
}
c.Config = pineconeConfig
default:
return fmt.Errorf("unknown vector store type: %s", temp.Type)
}
return nil
}
// NewVectorStore returns a new vector store based on the configuration.
func NewVectorStore(ctx context.Context, config *Config, logger schemas.Logger) (VectorStore, error) {
if config == nil {
return nil, fmt.Errorf("config cannot be nil")
}
if !config.Enabled {
return nil, fmt.Errorf("vector store is disabled")
}
switch config.Type {
case VectorStoreTypeWeaviate:
if config.Config == nil {
return nil, fmt.Errorf("weaviate config is required")
}
weaviateConfig, ok := config.Config.(WeaviateConfig)
if !ok {
return nil, fmt.Errorf("invalid weaviate config")
}
return newWeaviateStore(ctx, &weaviateConfig, logger)
case VectorStoreTypeRedis:
if config.Config == nil {
return nil, fmt.Errorf("redis config is required")
}
redisConfig, ok := config.Config.(RedisConfig)
if !ok {
return nil, fmt.Errorf("invalid redis config")
}
return newRedisStore(ctx, redisConfig, logger)
case VectorStoreTypeQdrant:
if config.Config == nil {
return nil, fmt.Errorf("qdrant config is required")
}
qdrantConfig, ok := config.Config.(QdrantConfig)
if !ok {
return nil, fmt.Errorf("invalid qdrant config")
}
return newQdrantStore(ctx, &qdrantConfig, logger)
case VectorStoreTypePinecone:
if config.Config == nil {
return nil, fmt.Errorf("pinecone config is required")
}
pineconeConfig, ok := config.Config.(PineconeConfig)
if !ok {
return nil, fmt.Errorf("invalid pinecone config")
}
return newPineconeStore(ctx, &pineconeConfig, logger)
}
return nil, fmt.Errorf("invalid vector store type: %s", config.Type)
}

View File

@@ -0,0 +1,46 @@
package vectorstore
import (
"math/rand"
"os"
"strconv"
"github.com/google/uuid"
)
// Helper functions
func getEnvWithDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvWithDefaultInt(key string, defaultValue int) (int, error) {
if value := os.Getenv(key); value != "" {
return strconv.Atoi(value)
}
return defaultValue, nil
}
func generateUUID() string {
return uuid.New().String()
}
func generateTestEmbedding(dim int) []float32 {
embedding := make([]float32, dim)
for i := range embedding {
embedding[i] = rand.Float32()*2 - 1 // Random values between -1 and 1
}
return embedding
}
func generateSimilarEmbedding(original []float32, similarity float32) []float32 {
similar := make([]float32, len(original))
for i := range similar {
// Add small random noise to create similar but not identical embedding
noise := (rand.Float32()*2 - 1) * (1 - similarity) * 0.1
similar[i] = original[i] + noise
}
return similar
}

View File

@@ -0,0 +1,15 @@
package vectorstore
import (
"context"
"time"
)
// withTimeout adds a timeout to the context if it is set.
func withTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if timeout > 0 {
return context.WithTimeout(ctx, timeout)
}
// No-op cancel to simplify call sites.
return ctx, func() {}
}

View File

@@ -0,0 +1,637 @@
package vectorstore
import (
"context"
"fmt"
"strings"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/weaviate/weaviate-go-client/v5/weaviate"
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
"github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
"github.com/weaviate/weaviate/entities/models"
)
// Default values for Weaviate vector index configuration
const (
// Default class names (Weaviate prefers PascalCase)
DefaultClassName = "BifrostStore"
)
// WeaviateConfig represents the configuration for the Weaviate vector store.
type WeaviateConfig struct {
// Connection settings
Scheme string `json:"scheme"` // "http" or "https" - REQUIRED
Host *schemas.EnvVar `json:"host"` // "localhost:8080" - REQUIRED
GrpcConfig *WeaviateGrpcConfig `json:"grpc_config,omitempty"` // grpc config for weaviate (optional)
// Authentication settings (optional)
APIKey *schemas.EnvVar `json:"api_key,omitempty"` // API key for authentication
Headers map[string]string `json:"headers,omitempty"` // Additional headers
// Connection settings
Timeout time.Duration `json:"timeout,omitempty"` // Request timeout (optional)
}
type WeaviateGrpcConfig struct {
// Host is the host of the weaviate server (host:port).
// If host is without a port number then the 80 port for insecured and 443 port for secured connections will be used.
Host *schemas.EnvVar `json:"host"`
// Secured is a boolean flag indicating if the connection is secured
Secured bool `json:"secured"`
}
// WeaviateStore represents the Weaviate vector store.
type WeaviateStore struct {
client *weaviate.Client
config *WeaviateConfig
logger schemas.Logger
}
// Ping checks if the Weaviate server is reachable.
func (s *WeaviateStore) Ping(ctx context.Context) error {
_, err := s.client.Misc().MetaGetter().Do(ctx)
return err
}
// Add stores a new object (with or without embedding)
func (s *WeaviateStore) Add(ctx context.Context, className string, id string, embedding []float32, metadata map[string]interface{}) error {
if strings.TrimSpace(id) == "" {
return fmt.Errorf("id is required")
}
obj := &models.Object{
Class: className,
Properties: metadata,
}
var err error
if len(embedding) > 0 {
_, err = s.client.Data().Creator().
WithClassName(className).
WithID(id).
WithProperties(obj.Properties).
WithVector(embedding).
Do(ctx)
} else {
_, err = s.client.Data().Creator().
WithClassName(className).
WithID(id).
WithProperties(obj.Properties).
Do(ctx)
}
return err
}
// GetChunk returns the "metadata" for a single key
func (s *WeaviateStore) GetChunk(ctx context.Context, className string, id string) (SearchResult, error) {
obj, err := s.client.Data().ObjectsGetter().
WithClassName(className).
WithID(id).
Do(ctx)
if err != nil {
return SearchResult{}, err
}
if len(obj) == 0 {
return SearchResult{}, fmt.Errorf("not found: %s", id)
}
props, ok := obj[0].Properties.(map[string]interface{})
if !ok {
return SearchResult{}, fmt.Errorf("invalid properties")
}
return SearchResult{
ID: id,
Score: nil,
Properties: props,
}, nil
}
// GetChunks returns multiple objects by ID
func (s *WeaviateStore) GetChunks(ctx context.Context, className string, ids []string) ([]SearchResult, error) {
out := make([]SearchResult, 0, len(ids))
for _, id := range ids {
obj, err := s.client.Data().ObjectsGetter().
WithClassName(className).
WithID(id).
Do(ctx)
if err != nil {
return nil, err
}
if len(obj) > 0 {
props, ok := obj[0].Properties.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid properties")
}
out = append(out, SearchResult{
ID: id,
Score: nil,
Properties: props,
})
}
}
return out, nil
}
// GetAll with filtering + pagination
func (s *WeaviateStore) GetAll(ctx context.Context, className string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
where := buildWeaviateFilter(queries)
fields := []graphql.Field{
{Name: "_additional", Fields: []graphql.Field{
{Name: "id"},
}},
}
for _, field := range selectFields {
fields = append(fields, graphql.Field{Name: field})
}
search := s.client.GraphQL().Get().
WithClassName(className).
WithLimit(int(limit)).
WithFields(fields...)
if where != nil {
search = search.WithWhere(where)
}
if cursor != nil {
search = search.WithAfter(*cursor)
}
resp, err := search.Do(ctx)
if err != nil {
return nil, nil, err
}
// Check for GraphQL errors
if len(resp.Errors) > 0 {
var errorMsgs []string
for _, err := range resp.Errors {
errorMsgs = append(errorMsgs, err.Message)
}
return nil, nil, fmt.Errorf("graphql errors: %v", errorMsgs)
}
data, ok := resp.Data["Get"].(map[string]interface{})
if !ok {
return nil, nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
}
objsRaw, exists := data[className]
if !exists {
// No results for this class - this is normal, not an error
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
return nil, nil, nil
}
objs, ok := objsRaw.([]interface{})
if !ok {
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
return nil, nil, nil
}
results := make([]SearchResult, 0, len(objs))
var nextCursor *string
for _, o := range objs {
obj, ok := o.(map[string]interface{})
if !ok {
continue
}
// Convert to SearchResult format for consistency
searchResult := SearchResult{
Properties: obj,
}
if additional, ok := obj["_additional"].(map[string]interface{}); ok {
if id, ok := additional["id"].(string); ok {
searchResult.ID = id
nextCursor = &id
}
}
results = append(results, searchResult)
}
return results, nextCursor, nil
}
// GetNearest with explicit filters only
func (s *WeaviateStore) GetNearest(
ctx context.Context,
className string,
vector []float32,
queries []Query,
selectFields []string,
threshold float64,
limit int64,
) ([]SearchResult, error) {
where := buildWeaviateFilter(queries)
fields := []graphql.Field{
{Name: "_additional", Fields: []graphql.Field{
{Name: "id"},
{Name: "certainty"},
}},
}
for _, field := range selectFields {
fields = append(fields, graphql.Field{Name: field})
}
nearVector := s.client.GraphQL().NearVectorArgBuilder().
WithVector(vector).
WithCertainty(float32(threshold))
search := s.client.GraphQL().Get().
WithClassName(className).
WithNearVector(nearVector).
WithLimit(int(limit)).
WithFields(fields...)
if where != nil {
search = search.WithWhere(where)
}
resp, err := search.Do(ctx)
if err != nil {
return nil, err
}
// Check for GraphQL errors
if len(resp.Errors) > 0 {
var errorMsgs []string
for _, err := range resp.Errors {
errorMsgs = append(errorMsgs, err.Message)
}
return nil, fmt.Errorf("graphql errors: %v", errorMsgs)
}
data, ok := resp.Data["Get"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid graphql response: missing 'Get' key, got: %+v", resp.Data)
}
objsRaw, exists := data[className]
if !exists {
// No results for this class - this is normal, not an error
s.logger.Debug(fmt.Sprintf("No results found for class '%s', available classes: %+v", className, data))
return nil, nil
}
objs, ok := objsRaw.([]interface{})
if !ok {
s.logger.Debug(fmt.Sprintf("Class '%s' exists but data is not an array: %+v", className, objsRaw))
return nil, nil
}
results := make([]SearchResult, 0, len(objs))
for _, o := range objs {
obj, ok := o.(map[string]interface{})
if !ok {
continue
}
additional, ok := obj["_additional"].(map[string]interface{})
if !ok {
continue
}
// Safely extract ID
idRaw, exists := additional["id"]
if !exists || idRaw == nil {
continue
}
id, ok := idRaw.(string)
if !ok {
continue
}
// Safely extract certainty/score with default value
var score float64
if certaintyRaw, exists := additional["certainty"]; exists && certaintyRaw != nil {
switch v := certaintyRaw.(type) {
case float64:
score = v
case float32:
score = float64(v)
case int:
score = float64(v)
case int64:
score = float64(v)
default:
score = 0.0 // Default score if type conversion fails
}
}
results = append(results, SearchResult{
ID: id,
Score: &score,
Properties: obj,
})
}
return results, nil
}
// Delete removes multiple objects by ID
func (s *WeaviateStore) Delete(ctx context.Context, className string, id string) error {
return s.client.Data().Deleter().
WithClassName(className).
WithID(id).
Do(ctx)
}
func (s *WeaviateStore) DeleteAll(ctx context.Context, className string, queries []Query) ([]DeleteResult, error) {
// Check if class exists first to avoid 500 errors from Weaviate
exists, err := s.client.Schema().ClassExistenceChecker().
WithClassName(className).
Do(ctx)
if err != nil {
return nil, fmt.Errorf("failed to check class existence: %w", err)
}
if !exists {
return []DeleteResult{}, nil // Class doesn't exist, nothing to delete
}
where := buildWeaviateFilter(queries)
res, err := s.client.Batch().ObjectsBatchDeleter().
WithClassName(className).
WithWhere(where).
Do(ctx)
if err != nil {
return nil, err
}
// NOTE: Weaviate is returning an empty array for Results.Objects, even on successful deletes.
results := make([]DeleteResult, 0, len(res.Results.Objects))
for _, obj := range res.Results.Objects {
result := DeleteResult{
ID: obj.ID.String(),
}
if obj.Status != nil {
switch *obj.Status {
case "SUCCESS":
result.Status = DeleteStatusSuccess
case "FAILED":
result.Status = DeleteStatusError
if obj.Errors != nil {
var errorMsgs []string
for _, err := range obj.Errors.Error {
errorMsgs = append(errorMsgs, err.Message)
}
result.Error = strings.Join(errorMsgs, ", ")
}
}
}
results = append(results, result)
}
return results, nil
}
func (s *WeaviateStore) Close(ctx context.Context, className string) error {
// nothing to close
return nil
}
// RequiresVectors returns true because Weaviate's HNSW index
// requires vectors for proper object indexing and retrieval.
func (s *WeaviateStore) RequiresVectors() bool {
return true
}
// newWeaviateStore creates a new Weaviate vector store.
func newWeaviateStore(ctx context.Context, config *WeaviateConfig, logger schemas.Logger) (*WeaviateStore, error) {
// Validate required config
if config.Scheme == "" || (config.Host == nil || config.Host.GetValue() == "") {
return nil, fmt.Errorf("weaviate scheme and host are required")
}
// Build client configuration
cfg := weaviate.Config{
Scheme: config.Scheme,
Host: config.Host.GetValue(),
}
// Add authentication if provided
if config.APIKey != nil && config.APIKey.GetValue() != "" {
cfg.AuthConfig = auth.ApiKey{Value: config.APIKey.GetValue()}
}
// Add grpc config if provided
if config.GrpcConfig != nil {
if config.GrpcConfig.Host == nil || config.GrpcConfig.Host.GetValue() == "" {
return nil, fmt.Errorf("weaviate grpc host is required")
}
cfg.GrpcConfig = &grpc.Config{
Host: config.GrpcConfig.Host.GetValue(),
Secured: config.GrpcConfig.Secured,
}
}
// Add custom headers if provided
if len(config.Headers) > 0 {
cfg.Headers = config.Headers
}
// Create client
client, err := weaviate.NewClient(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create weaviate client: %w", err)
}
// Test connection with meta endpoint
testCtx := ctx
if config.Timeout > 0 {
var cancel context.CancelFunc
testCtx, cancel = context.WithTimeout(ctx, config.Timeout)
defer cancel()
}
_, err = client.Misc().MetaGetter().Do(testCtx)
if err != nil {
return nil, fmt.Errorf("failed to connect to weaviate: %w", err)
}
store := &WeaviateStore{
client: client,
config: config,
logger: logger,
}
return store, nil
}
func (s *WeaviateStore) CreateNamespace(ctx context.Context, className string, dimension int, properties map[string]VectorStoreProperties) error {
// Check if class exists
exists, err := s.client.Schema().ClassExistenceChecker().
WithClassName(className).
Do(ctx)
if err != nil {
return fmt.Errorf("failed to check class existence: %w", err)
}
if exists {
return nil // Schema already exists
}
// Create properties
weaviateProperties := []*models.Property{}
for name, prop := range properties {
var dataType []string
switch prop.DataType {
case VectorStorePropertyTypeString:
dataType = []string{"string"}
case VectorStorePropertyTypeInteger:
dataType = []string{"int"}
case VectorStorePropertyTypeBoolean:
dataType = []string{"boolean"}
case VectorStorePropertyTypeStringArray:
dataType = []string{"string[]"}
}
weaviateProperties = append(weaviateProperties, &models.Property{
Name: name,
DataType: dataType,
Description: prop.Description,
})
}
// Create class schema with all fields we need
classSchema := &models.Class{
Class: className,
Properties: weaviateProperties,
VectorIndexType: "hnsw",
Vectorizer: "none", // We provide our own vectors
}
if dimension > 0 {
classSchema.VectorIndexConfig = map[string]interface{}{
"vectorDimensions": dimension,
}
}
err = s.client.Schema().ClassCreator().
WithClass(classSchema).
Do(ctx)
if err != nil {
return fmt.Errorf("failed to create class schema: %w", err)
}
return nil
}
func (s *WeaviateStore) DeleteNamespace(ctx context.Context, className string) error {
exists, err := s.client.Schema().ClassExistenceChecker().
WithClassName(className).
Do(ctx)
if err != nil {
return fmt.Errorf("failed to check class existence: %w", err)
}
if !exists {
return nil // Schema already does not exist
} else {
return s.client.Schema().ClassDeleter().
WithClassName(className).
Do(ctx)
}
}
// buildWeaviateFilter converts []Query → Weaviate WhereFilter
func buildWeaviateFilter(queries []Query) *filters.WhereBuilder {
if len(queries) == 0 {
return nil
}
var operands []*filters.WhereBuilder
for _, q := range queries {
// Convert string operator to filters operator
operator := convertOperator(q.Operator)
fieldPath := strings.Split(q.Field, ".")
whereClause := filters.Where().
WithPath(fieldPath).
WithOperator(operator)
// Special handling for IsNull and IsNotNull
switch q.Operator {
case QueryOperatorIsNull:
whereClause = whereClause.WithValueBoolean(true)
case QueryOperatorIsNotNull:
whereClause = whereClause.WithValueBoolean(false)
default:
// Set value based on type
switch v := q.Value.(type) {
case string:
whereClause = whereClause.WithValueString(v)
case int:
whereClause = whereClause.WithValueInt(int64(v))
case int64:
whereClause = whereClause.WithValueInt(v)
case float32:
whereClause = whereClause.WithValueNumber(float64(v))
case float64:
whereClause = whereClause.WithValueNumber(v)
case bool:
whereClause = whereClause.WithValueBoolean(v)
default:
// Fallback to string conversion
whereClause = whereClause.WithValueString(fmt.Sprintf("%v", v))
}
}
operands = append(operands, whereClause)
}
if len(operands) == 1 {
return operands[0]
}
// Create AND filter for multiple operands
return filters.Where().
WithOperator(filters.And).
WithOperands(operands)
}
// convertOperator converts string operator to filters operator
func convertOperator(op QueryOperator) filters.WhereOperator {
switch op {
case QueryOperatorEqual:
return filters.Equal
case QueryOperatorNotEqual:
return filters.NotEqual
case QueryOperatorLessThan:
return filters.LessThan
case QueryOperatorLessThanOrEqual:
return filters.LessThanEqual
case QueryOperatorGreaterThan:
return filters.GreaterThan
case QueryOperatorGreaterThanOrEqual:
return filters.GreaterThanEqual
case QueryOperatorLike:
return filters.Like
case QueryOperatorContainsAny:
return filters.ContainsAny
case QueryOperatorContainsAll:
return filters.ContainsAll
case QueryOperatorIsNull:
return filters.IsNull
case QueryOperatorIsNotNull: // IsNotNull is not supported by Weaviate, so we use IsNull and negate it.
return filters.IsNull
default:
// Default to Equal if unknown
return filters.Equal
}
}

View File

@@ -0,0 +1,812 @@
package vectorstore
import (
"context"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
"github.com/weaviate/weaviate/entities/models"
)
// Test constants
const (
TestTimeout = 30 * time.Second
TestClassName = "TestWeaviate"
TestEmbeddingDim = 384
DefaultTestScheme = "http"
DefaultTestHost = "localhost:9000"
DefaultTestTimeout = 10 * time.Second
)
// TestSetup provides common test infrastructure
type TestSetup struct {
Store *WeaviateStore
Logger schemas.Logger
Config WeaviateConfig
ctx context.Context
cancel context.CancelFunc
}
// NewTestSetup creates a test setup with environment-driven configuration
func NewTestSetup(t *testing.T) *TestSetup {
// Get configuration from environment variables
scheme := getEnvWithDefault("WEAVIATE_SCHEME", DefaultTestScheme)
host := schemas.NewEnvVar(getEnvWithDefault("WEAVIATE_HOST", DefaultTestHost))
timeoutStr := getEnvWithDefault("WEAVIATE_TIMEOUT", "10s")
timeout, err := time.ParseDuration(timeoutStr)
if err != nil {
timeout = DefaultTestTimeout
}
config := WeaviateConfig{
Scheme: scheme,
Host: host,
APIKey: schemas.NewEnvVar("env.WEAVIATE_API_KEY"),
Timeout: timeout,
}
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
ctx, cancel := context.WithTimeout(context.Background(), TestTimeout)
store, err := newWeaviateStore(ctx, &config, logger)
if err != nil {
cancel()
t.Fatalf("Failed to create Weaviate store: %v", err)
}
setup := &TestSetup{
Store: store,
Logger: logger,
Config: config,
ctx: ctx,
cancel: cancel,
}
// Ensure class exists for integration tests
if !testing.Short() {
setup.ensureClassExists(t)
}
return setup
}
// Cleanup cleans up test resources
func (ts *TestSetup) Cleanup(t *testing.T) {
defer ts.cancel()
if !testing.Short() {
// Clean up test data
ts.cleanupTestData(t)
}
if err := ts.Store.Close(ts.ctx, TestClassName); err != nil {
t.Logf("Warning: Failed to close store: %v", err)
}
}
// ensureClassExists creates the test class in Weaviate
func (ts *TestSetup) ensureClassExists(t *testing.T) {
// Try to get class schema first
exists, err := ts.Store.client.Schema().ClassGetter().
WithClassName(TestClassName).
Do(ts.ctx)
if err == nil && exists != nil {
t.Logf("Class %s already exists", TestClassName)
return
}
// Create class with minimal schema - let Weaviate auto-create properties
class := &models.Class{
Class: TestClassName,
Properties: []*models.Property{
{
Name: "key",
DataType: []string{"text"},
},
{
Name: "test_type",
DataType: []string{"text"},
},
{
Name: "size",
DataType: []string{"int"},
},
{
Name: "public",
DataType: []string{"boolean"},
},
},
VectorIndexConfig: map[string]interface{}{
"distance": "cosine",
},
}
err = ts.Store.client.Schema().ClassCreator().
WithClass(class).
Do(ts.ctx)
if err != nil {
t.Logf("Warning: Failed to create test class %s: %v", TestClassName, err)
t.Logf("This might be due to auto-schema creation. Continuing...")
} else {
t.Logf("Created test class: %s", TestClassName)
}
}
// cleanupTestData removes all test objects from the class
func (ts *TestSetup) cleanupTestData(t *testing.T) {
// Delete all objects in the test class
allTestKeys, _, err := ts.Store.GetAll(ts.ctx, TestClassName, []Query{}, []string{}, nil, 1000)
if err != nil {
t.Logf("Warning: Failed to get all test keys: %v", err)
return
}
for _, key := range allTestKeys {
err := ts.Store.Delete(ts.ctx, TestClassName, key.ID)
if err != nil {
t.Logf("Warning: Failed to delete test key %s: %v", key.ID, err)
}
}
t.Logf("Cleaned up test class: %s", TestClassName)
}
// ============================================================================
// UNIT TESTS
// ============================================================================
func TestWeaviateConfig_Validation(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
ctx := context.Background()
tests := []struct {
name string
config WeaviateConfig
expectError bool
errorMsg string
}{
{
name: "valid config",
config: WeaviateConfig{
Scheme: "http",
Host: schemas.NewEnvVar("localhost:8080"),
},
expectError: false,
},
{
name: "missing scheme",
config: WeaviateConfig{
Host: schemas.NewEnvVar("localhost:8080"),
},
expectError: true,
errorMsg: "scheme and host are required",
},
{
name: "missing host",
config: WeaviateConfig{
Scheme: "http",
},
expectError: true,
errorMsg: "scheme and host are required",
},
{
name: "with api key",
config: WeaviateConfig{
Scheme: "https",
Host: schemas.NewEnvVar("cluster.weaviate.network"),
APIKey: schemas.NewEnvVar("test-key"),
},
expectError: false,
},
{
name: "with custom headers",
config: WeaviateConfig{
Scheme: "http",
Host: schemas.NewEnvVar("localhost:8080"),
Headers: map[string]string{
"Custom-Header": "value",
},
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store, err := newWeaviateStore(ctx, &tt.config, logger)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, store)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
// Note: This will fail with connection error in unit tests
// but should pass config validation
assert.Nil(t, store) // Expected due to no real Weaviate instance
assert.Error(t, err) // Connection error expected
}
})
}
}
func TestDefaultClassName(t *testing.T) {
config := WeaviateConfig{
Scheme: "http",
Host: schemas.NewEnvVar("localhost:8080"),
}
// This will fail to connect but should set default class name
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
_, err := newWeaviateStore(context.Background(), &config, logger)
// Should fail with connection error, but we can't test the default class name
// without mocking the client, which would be more complex
assert.Error(t, err)
}
func TestBuildWeaviateFilter(t *testing.T) {
tests := []struct {
name string
queries []Query
expected *filters.WhereBuilder // We'll test the structure, not exact equality
isNil bool
}{
{
name: "empty queries",
queries: []Query{},
expected: nil,
isNil: true,
},
{
name: "single string query",
queries: []Query{
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
},
isNil: false,
},
{
name: "single numeric query",
queries: []Query{
{Field: "size", Operator: QueryOperatorGreaterThan, Value: 1000},
},
isNil: false,
},
{
name: "multiple queries (AND)",
queries: []Query{
{Field: "category", Operator: QueryOperatorEqual, Value: "tech"},
{Field: "public", Operator: QueryOperatorEqual, Value: true},
},
isNil: false,
},
{
name: "mixed types",
queries: []Query{
{Field: "name", Operator: QueryOperatorLike, Value: "test%"},
{Field: "count", Operator: QueryOperatorLessThan, Value: int64(100)},
{Field: "active", Operator: QueryOperatorEqual, Value: true},
{Field: "score", Operator: QueryOperatorGreaterThanOrEqual, Value: 95.5},
},
isNil: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildWeaviateFilter(tt.queries)
if tt.isNil {
assert.Nil(t, result)
} else {
assert.NotNil(t, result)
// We can't easily test the internal structure without reflection
// or implementing String() methods, but we verify it's not nil
}
})
}
}
func TestConvertOperator(t *testing.T) {
tests := []struct {
input QueryOperator
expected filters.WhereOperator
}{
{QueryOperatorEqual, filters.Equal},
{QueryOperatorNotEqual, filters.NotEqual},
{QueryOperatorLessThan, filters.LessThan},
{QueryOperatorLessThanOrEqual, filters.LessThanEqual},
{QueryOperatorGreaterThan, filters.GreaterThan},
{QueryOperatorGreaterThanOrEqual, filters.GreaterThanEqual},
{QueryOperatorLike, filters.Like},
{QueryOperatorContainsAny, filters.ContainsAny},
{QueryOperatorContainsAll, filters.ContainsAll},
{QueryOperatorIsNull, filters.IsNull},
{QueryOperatorIsNotNull, filters.IsNull},
}
for _, tt := range tests {
t.Run(string(tt.input), func(t *testing.T) {
result := convertOperator(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// ============================================================================
// INTEGRATION TESTS (require real Weaviate instance)
// ============================================================================
func TestWeaviateStore_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewTestSetup(t)
defer setup.Cleanup(t)
t.Run("Add and GetChunk", func(t *testing.T) {
testKey := generateUUID()
embedding := generateTestEmbedding(TestEmbeddingDim)
metadata := map[string]interface{}{
"type": "document",
"size": 1024,
"public": true,
}
// Add object
err := setup.Store.Add(setup.ctx, TestClassName, testKey, embedding, metadata)
require.NoError(t, err)
// Small delay to ensure consistency
time.Sleep(100 * time.Millisecond)
// Get single chunk
result, err := setup.Store.GetChunk(setup.ctx, TestClassName, testKey)
require.NoError(t, err)
assert.NotEmpty(t, result)
assert.Equal(t, "document", result.Properties["type"]) // Should contain metadata
})
t.Run("Add without embedding", func(t *testing.T) {
testKey := generateUUID()
metadata := map[string]interface{}{
"type": "metadata-only",
}
// Add object without embedding
err := setup.Store.Add(setup.ctx, TestClassName, testKey, nil, metadata)
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
// Retrieve it
result, err := setup.Store.GetChunk(setup.ctx, TestClassName, testKey)
require.NoError(t, err)
assert.Equal(t, "metadata-only", result.Properties["type"])
})
}
func TestWeaviateStore_FilteringScenarios(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewTestSetup(t)
defer setup.Cleanup(t)
// Setup test data for filtering scenarios
testData := []struct {
key string
metadata map[string]interface{}
}{
{
generateUUID(),
map[string]interface{}{
"type": "pdf",
"size": 1024,
"public": true,
"author": "alice",
},
},
{
generateUUID(),
map[string]interface{}{
"type": "docx",
"size": 2048,
"public": false,
"author": "bob",
},
},
{
generateUUID(),
map[string]interface{}{
"type": "pdf",
"size": 512,
"public": true,
"author": "alice",
},
},
{
generateUUID(),
map[string]interface{}{
"type": "txt",
"size": 256,
"public": true,
"author": "charlie",
},
},
}
filterFields := []string{"type", "size", "public", "author"}
// Add all test data
for _, item := range testData {
embedding := generateTestEmbedding(TestEmbeddingDim)
err := setup.Store.Add(setup.ctx, TestClassName, item.key, embedding, item.metadata)
require.NoError(t, err)
}
time.Sleep(500 * time.Millisecond) // Wait for consistency
t.Run("Filter by numeric comparison", func(t *testing.T) {
queries := []Query{
{Field: "size", Operator: "GreaterThan", Value: 1000},
}
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2) // doc1 (1024) and doc2 (2048)
})
t.Run("Filter by boolean", func(t *testing.T) {
queries := []Query{
{Field: "public", Operator: "Equal", Value: true},
}
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 3) // doc1, doc3, doc4
})
t.Run("Multiple filters (AND)", func(t *testing.T) {
queries := []Query{
{Field: "type", Operator: "Equal", Value: "pdf"},
{Field: "public", Operator: "Equal", Value: true},
}
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2) // doc1 and doc3
})
t.Run("Complex multi-condition filter", func(t *testing.T) {
queries := []Query{
{Field: "author", Operator: "Equal", Value: "alice"},
{Field: "size", Operator: "LessThan", Value: 2000},
{Field: "public", Operator: "Equal", Value: true},
}
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, queries, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2) // doc1 and doc3 (both by alice, < 2000 size, public)
})
t.Run("Pagination test", func(t *testing.T) {
// Test with limit of 2
results, cursor, err := setup.Store.GetAll(setup.ctx, TestClassName, nil, filterFields, nil, 2)
require.NoError(t, err)
assert.Len(t, results, 2)
if cursor != nil {
// Get next page
nextResults, _, err := setup.Store.GetAll(setup.ctx, TestClassName, nil, filterFields, cursor, 2)
require.NoError(t, err)
assert.LessOrEqual(t, len(nextResults), 2)
t.Logf("First page: %d results, Next page: %d results", len(results), len(nextResults))
}
})
}
func TestWeaviateStore_CompleteUseCases(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewTestSetup(t)
defer setup.Cleanup(t)
t.Run("Document Storage & Retrieval Scenario", func(t *testing.T) {
// Add documents with different types
documents := []struct {
key string
embedding []float32
metadata map[string]interface{}
}{
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{"type": "pdf", "size": 1024, "public": true},
},
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{"type": "docx", "size": 2048, "public": false},
},
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{"type": "pdf", "size": 512, "public": true},
},
}
filterFields := []string{"type", "size", "public", "author"}
for _, doc := range documents {
err := setup.Store.Add(setup.ctx, TestClassName, doc.key, doc.embedding, doc.metadata)
require.NoError(t, err)
}
time.Sleep(300 * time.Millisecond)
// Test various retrieval patterns
// Get PDF documents
pdfQuery := []Query{{Field: "type", Operator: "Equal", Value: "pdf"}}
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, pdfQuery, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2) // doc1, doc3
// Get large documents (size > 1000)
sizeQuery := []Query{{Field: "size", Operator: "GreaterThan", Value: 1000}}
results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, sizeQuery, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2) // doc1, doc2
// Get public PDFs
combinedQuery := []Query{
{Field: "public", Operator: "Equal", Value: true},
{Field: "type", Operator: "Equal", Value: "pdf"},
}
results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, combinedQuery, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2) // doc1, doc3
// Vector similarity search
queryEmbedding := documents[0].embedding // Similar to doc1
vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, queryEmbedding, nil, filterFields, 0.8, 10)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(vectorResults), 1)
})
t.Run("User Content Management Scenario", func(t *testing.T) {
// Add user content with metadata
userContent := []struct {
key string
embedding []float32
metadata map[string]interface{}
}{
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{"user": "alice", "lang": "en", "category": "tech"},
},
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{"user": "bob", "lang": "es", "category": "tech"},
},
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{"user": "alice", "lang": "en", "category": "sports"},
},
}
filterFields := []string{"user", "lang", "category"}
for _, content := range userContent {
err := setup.Store.Add(setup.ctx, TestClassName, content.key, content.embedding, content.metadata)
require.NoError(t, err)
}
time.Sleep(300 * time.Millisecond)
// Test user-specific filtering
aliceQuery := []Query{{Field: "user", Operator: "Equal", Value: "alice"}}
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, aliceQuery, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 2) // Alice's content
// English tech content
techEnQuery := []Query{
{Field: "lang", Operator: "Equal", Value: "en"},
{Field: "category", Operator: "Equal", Value: "tech"},
}
results, _, err = setup.Store.GetAll(setup.ctx, TestClassName, techEnQuery, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 1) // user1_content
// Alice's similar content (semantic search with user filter)
aliceFilter := []Query{{Field: "user", Operator: "Equal", Value: "alice"}}
queryEmbedding := userContent[0].embedding
vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, queryEmbedding, aliceFilter, filterFields, 0.1, 10)
require.NoError(t, err)
assert.Len(t, vectorResults, 2) // Both of Alice's content
})
t.Run("Semantic Cache-like Workflow", func(t *testing.T) {
// Add request-response pairs with parameters
cacheEntries := []struct {
key string
embedding []float32
metadata map[string]interface{}
}{
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{
"request_hash": "abc123",
"user": "u1",
"lang": "en",
"response": "answer1",
},
},
{
generateUUID(),
generateTestEmbedding(TestEmbeddingDim),
map[string]interface{}{
"request_hash": "def456",
"user": "u1",
"lang": "es",
"response": "answer2",
},
},
}
filterFields := []string{"request_hash", "user", "lang", "response"}
for _, entry := range cacheEntries {
err := setup.Store.Add(setup.ctx, TestClassName, entry.key, entry.embedding, entry.metadata)
require.NoError(t, err)
}
time.Sleep(300 * time.Millisecond)
// Test hash-based direct retrieval (exact match)
hashQuery := []Query{{Field: "request_hash", Operator: "Equal", Value: "abc123"}}
results, _, err := setup.Store.GetAll(setup.ctx, TestClassName, hashQuery, filterFields, nil, 10)
require.NoError(t, err)
assert.Len(t, results, 1)
// Test semantic search with user and language filters
userLangFilter := []Query{
{Field: "user", Operator: "Equal", Value: "u1"},
{Field: "lang", Operator: "Equal", Value: "en"},
}
similarEmbedding := generateSimilarEmbedding(cacheEntries[0].embedding, 0.9)
vectorResults, err := setup.Store.GetNearest(setup.ctx, TestClassName, similarEmbedding, userLangFilter, filterFields, 0.7, 10)
require.NoError(t, err)
assert.Len(t, vectorResults, 1) // Should find English content for u1
})
}
// ============================================================================
// INTERFACE COMPLIANCE TESTS
// ============================================================================
func TestWeaviateStore_InterfaceCompliance(t *testing.T) {
// Verify that WeaviateStore implements VectorStore interface
var _ VectorStore = (*WeaviateStore)(nil)
}
func TestVectorStoreFactory_Weaviate(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo)
config := &Config{
Enabled: true,
Type: VectorStoreTypeWeaviate,
Config: WeaviateConfig{
Scheme: getEnvWithDefault("WEAVIATE_SCHEME", DefaultTestScheme),
Host: schemas.NewEnvVar(getEnvWithDefault("WEAVIATE_HOST", DefaultTestHost)),
APIKey: schemas.NewEnvVar("env.WEAVIATE_API_KEY"),
},
}
store, err := NewVectorStore(context.Background(), config, logger)
if err != nil {
t.Skipf("Could not create Weaviate store: %v", err)
}
defer store.Close(context.Background(), TestClassName)
// Verify it's actually a WeaviateStore
weaviateStore, ok := store.(*WeaviateStore)
assert.True(t, ok)
assert.NotNil(t, weaviateStore)
}
func TestWeaviateStore_NamespaceDimensionHandling(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration tests in short mode")
}
setup := NewTestSetup(t)
defer setup.Cleanup(t)
testClassName := "TestDimensionHandling"
t.Run("Recreate class with different dimension should not crash", func(t *testing.T) {
properties := map[string]VectorStoreProperties{
"type": {DataType: VectorStorePropertyTypeString},
"test": {DataType: VectorStorePropertyTypeString},
}
// Step 1: Create class with dimension 512
err := setup.Store.CreateNamespace(setup.ctx, testClassName, 512, properties)
require.NoError(t, err)
// Add a document with 512-dimensional embedding
testKey512 := generateUUID()
embedding512 := generateTestEmbedding(512)
metadata := map[string]interface{}{
"type": "test_doc",
"test": "dimension_512",
}
err = setup.Store.Add(setup.ctx, testClassName, testKey512, embedding512, metadata)
require.NoError(t, err)
// Verify it was added
result, err := setup.Store.GetChunk(setup.ctx, testClassName, testKey512)
require.NoError(t, err)
assert.Equal(t, "dimension_512", result.Properties["test"])
// Step 2: Delete the class/namespace
err = setup.Store.DeleteNamespace(setup.ctx, testClassName)
require.NoError(t, err)
// Step 3: Create class with same name but different dimension - should not crash
err = setup.Store.CreateNamespace(setup.ctx, testClassName, 1024, properties)
require.NoError(t, err)
// Add a document with 1024-dimensional embedding
testKey1024 := generateUUID()
embedding1024 := generateTestEmbedding(1024)
metadata1024 := map[string]interface{}{
"type": "test_doc",
"test": "dimension_1024",
}
err = setup.Store.Add(setup.ctx, testClassName, testKey1024, embedding1024, metadata1024)
require.NoError(t, err)
// Verify new document exists
result, err = setup.Store.GetChunk(setup.ctx, testClassName, testKey1024)
require.NoError(t, err)
assert.Equal(t, "dimension_1024", result.Properties["test"])
// Verify vector search works with new dimension
vectorResults, err := setup.Store.GetNearest(setup.ctx, testClassName, embedding1024, nil, []string{"type", "test"}, 0.8, 10)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(vectorResults), 1)
assert.NotNil(t, vectorResults[0].Score)
// Cleanup
err = setup.Store.DeleteNamespace(setup.ctx, testClassName)
if err != nil {
t.Logf("Warning: Failed to cleanup class: %v", err)
}
})
}