1754 lines
48 KiB
Go
1754 lines
48 KiB
Go
package vectorstore
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Limits shared by the Redis vector store's search and scan paths.
const (
	// BatchLimit is the page size applied to pagination and batch operations
	// when the caller does not supply one.
	BatchLimit = 100

	// RedisMaxSearchResults is the largest number of results Redis Search
	// returns from a single query. It mirrors the default MAXSEARCHRESULTS
	// configuration in Redis Search.
	RedisMaxSearchResults = 10000
)
|
|
|
|
type RedisConfig struct {
|
|
// Connection settings
|
|
Addr *schemas.EnvVar `json:"addr"` // Redis server address (host:port) - REQUIRED
|
|
Username *schemas.EnvVar `json:"username,omitempty"` // Username for Redis AUTH (optional)
|
|
Password *schemas.EnvVar `json:"password,omitempty"` // Password for Redis AUTH (optional)
|
|
DB *schemas.EnvVar `json:"db,omitempty"` // Redis database number (default: 0)
|
|
|
|
// TLS settings
|
|
UseTLS *schemas.EnvVar `json:"use_tls,omitempty"` // Enable TLS for connection (default: false)
|
|
InsecureSkipVerify *schemas.EnvVar `json:"insecure_skip_verify,omitempty"` // Skip TLS cert verification (default: false)
|
|
CACertPEM *schemas.EnvVar `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for Redis/Valkey TLS
|
|
|
|
// Cluster mode
|
|
ClusterMode *schemas.EnvVar `json:"cluster_mode,omitempty"` // Use Redis Cluster client (default: false)
|
|
|
|
// Connection pool and timeout settings (passed directly to Redis client)
|
|
PoolSize int `json:"pool_size,omitempty"` // Maximum number of socket connections (optional)
|
|
MaxActiveConns int `json:"max_active_conns,omitempty"` // Maximum number of active connections (optional)
|
|
MinIdleConns int `json:"min_idle_conns,omitempty"` // Minimum number of idle connections (optional)
|
|
MaxIdleConns int `json:"max_idle_conns,omitempty"` // Maximum number of idle connections (optional)
|
|
ConnMaxLifetime time.Duration `json:"conn_max_lifetime,omitempty"` // Connection maximum lifetime (optional)
|
|
ConnMaxIdleTime time.Duration `json:"conn_max_idle_time,omitempty"` // Connection maximum idle time (optional)
|
|
DialTimeout time.Duration `json:"dial_timeout,omitempty"` // Timeout for socket connection (optional)
|
|
ReadTimeout time.Duration `json:"read_timeout,omitempty"` // Timeout for socket reads (optional)
|
|
WriteTimeout time.Duration `json:"write_timeout,omitempty"` // Timeout for socket writes (optional)
|
|
ContextTimeout time.Duration `json:"context_timeout,omitempty"` // Timeout for Redis operations (optional)
|
|
}
|
|
|
|
// RedisStore represents the Redis vector store.
|
|
type RedisStore struct {
|
|
client redis.UniversalClient
|
|
config RedisConfig
|
|
logger schemas.Logger
|
|
|
|
namespaceFieldTypesMu sync.RWMutex
|
|
namespaceFieldTypes map[string]map[string]VectorStorePropertyType
|
|
}
|
|
|
|
// Ping checks if the Redis server is reachable.
|
|
func (s *RedisStore) Ping(ctx context.Context) error {
|
|
return s.client.Ping(ctx).Err()
|
|
}
|
|
|
|
// CreateNamespace creates a new namespace in the Redis vector store.
|
|
func (s *RedisStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
// Check if index already exists
|
|
infoResult := s.client.Do(ctx, "FT.INFO", namespace)
|
|
if infoResult.Err() == nil {
|
|
s.cacheNamespaceFieldTypes(namespace, properties)
|
|
return nil // Index already exists
|
|
}
|
|
if err := infoResult.Err(); err != nil && strings.Contains(strings.ToLower(err.Error()), "unknown command") {
|
|
return fmt.Errorf("search module not available: please use Redis Stack or a Valkey bundle with search support (FT.* commands required). original error: %w", err)
|
|
}
|
|
|
|
// Extract metadata field names from properties
|
|
var metadataFields []string
|
|
for fieldName := range properties {
|
|
metadataFields = append(metadataFields, fieldName)
|
|
}
|
|
|
|
// Create index with VECTOR field + metadata fields
|
|
keyPrefix := fmt.Sprintf("%s:", namespace)
|
|
|
|
if dimension <= 0 {
|
|
return fmt.Errorf("redis vector index %q: dimension must be > 0 (got %d)", namespace, dimension)
|
|
}
|
|
|
|
args := []interface{}{
|
|
"FT.CREATE", namespace,
|
|
"ON", "HASH",
|
|
"PREFIX", "1", keyPrefix,
|
|
"SCHEMA",
|
|
// Native vector field with HNSW algorithm
|
|
"embedding", "VECTOR", "HNSW", "6",
|
|
"TYPE", "FLOAT32",
|
|
"DIM", dimension,
|
|
"DISTANCE_METRIC", "COSINE",
|
|
}
|
|
|
|
// Add all metadata fields as TEXT with exact matching
|
|
// All values are converted to strings for consistent searching
|
|
for _, field := range metadataFields {
|
|
// Detect field type from VectorStoreProperties
|
|
prop := properties[field]
|
|
switch prop.DataType {
|
|
case VectorStorePropertyTypeInteger:
|
|
args = append(args, field, "NUMERIC")
|
|
default:
|
|
args = append(args, field, "TAG")
|
|
}
|
|
}
|
|
|
|
// Create the index
|
|
if err := s.client.Do(ctx, args...).Err(); err != nil {
|
|
return fmt.Errorf("failed to create semantic vector index %s: %w", namespace, err)
|
|
}
|
|
|
|
s.cacheNamespaceFieldTypes(namespace, properties)
|
|
return nil
|
|
}
|
|
|
|
// GetChunk retrieves a chunk from the Redis vector store.
|
|
func (s *RedisStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
if strings.TrimSpace(id) == "" {
|
|
return SearchResult{}, fmt.Errorf("id is required")
|
|
}
|
|
|
|
// Create key with namespace
|
|
key := buildKey(namespace, id)
|
|
|
|
// Get all fields from the hash
|
|
result := s.client.HGetAll(ctx, key)
|
|
if result.Err() != nil {
|
|
return SearchResult{}, fmt.Errorf("failed to get chunk: %w", result.Err())
|
|
}
|
|
|
|
fields := result.Val()
|
|
if len(fields) == 0 {
|
|
return SearchResult{}, fmt.Errorf("chunk not found: %s", id)
|
|
}
|
|
|
|
// Build SearchResult
|
|
searchResult := SearchResult{
|
|
ID: id,
|
|
Properties: make(map[string]interface{}),
|
|
}
|
|
|
|
// Parse fields
|
|
for k, v := range fields {
|
|
searchResult.Properties[k] = v
|
|
}
|
|
|
|
return searchResult, nil
|
|
}
|
|
|
|
// GetChunks retrieves multiple chunks from the Redis vector store.
|
|
func (s *RedisStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
if len(ids) == 0 {
|
|
return []SearchResult{}, nil
|
|
}
|
|
|
|
// Create keys with namespace
|
|
keys := make([]string, len(ids))
|
|
for i, id := range ids {
|
|
if strings.TrimSpace(id) == "" {
|
|
return nil, fmt.Errorf("id cannot be empty at index %d", i)
|
|
}
|
|
keys[i] = buildKey(namespace, id)
|
|
}
|
|
|
|
// Use pipeline for efficient batch retrieval
|
|
pipe := s.client.Pipeline()
|
|
cmds := make([]*redis.MapStringStringCmd, len(keys))
|
|
|
|
for i, key := range keys {
|
|
cmds[i] = pipe.HGetAll(ctx, key)
|
|
}
|
|
|
|
// Execute pipeline
|
|
_, err := pipe.Exec(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute pipeline: %w", err)
|
|
}
|
|
|
|
// Process results
|
|
var results []SearchResult
|
|
for i, cmd := range cmds {
|
|
if cmd.Err() != nil {
|
|
// Log error but continue with other results
|
|
s.logger.Debug(fmt.Sprintf("failed to get chunk %s: %v", ids[i], cmd.Err()))
|
|
continue
|
|
}
|
|
|
|
fields := cmd.Val()
|
|
if len(fields) == 0 {
|
|
// Chunk not found, skip
|
|
continue
|
|
}
|
|
|
|
// Build SearchResult
|
|
searchResult := SearchResult{
|
|
ID: ids[i],
|
|
Properties: make(map[string]interface{}),
|
|
}
|
|
|
|
// Parse fields
|
|
for k, v := range fields {
|
|
searchResult.Properties[k] = v
|
|
}
|
|
|
|
results = append(results, searchResult)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// GetAll retrieves all chunks from the Redis vector store.
|
|
func (s *RedisStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
// Set default limit if not provided
|
|
if limit < 0 {
|
|
limit = BatchLimit
|
|
}
|
|
|
|
// Build Redis query from the provided queries
|
|
redisQuery := buildRedisQuery(queries, s.getNamespaceFieldTypes(namespace))
|
|
|
|
// When limit=0 (get all), use internal pagination to avoid exceeding Redis MAXSEARCHRESULTS
|
|
if limit == 0 {
|
|
return s.getAllWithPagination(ctx, namespace, redisQuery, queries, selectFields)
|
|
}
|
|
|
|
// For explicit limit, cap to Redis maximum and use single query with cursor support
|
|
searchLimit := limit
|
|
if searchLimit > RedisMaxSearchResults {
|
|
searchLimit = RedisMaxSearchResults
|
|
}
|
|
|
|
// Add OFFSET for pagination if cursor is provided
|
|
offset, err := parseOffsetCursor(cursor)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
results, err := s.executeSearch(ctx, namespace, redisQuery, queries, selectFields, offset, int(searchLimit))
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Implement cursor-based pagination using OFFSET
|
|
var nextCursor *string = nil
|
|
if cursor != nil && *cursor != "" {
|
|
if len(results) == int(limit) && limit > 0 {
|
|
offset, err := strconv.ParseInt(*cursor, 10, 64)
|
|
if err == nil {
|
|
nextOffset := offset + limit
|
|
nextCursorStr := strconv.FormatInt(nextOffset, 10)
|
|
nextCursor = &nextCursorStr
|
|
}
|
|
}
|
|
} else if len(results) == int(limit) && limit > 0 {
|
|
nextCursorStr := strconv.FormatInt(limit, 10)
|
|
nextCursor = &nextCursorStr
|
|
}
|
|
|
|
return results, nextCursor, nil
|
|
}
|
|
|
|
// getAllWithPagination fetches all matching results using internal pagination to avoid
|
|
// exceeding Redis Search's MAXSEARCHRESULTS limit (default 10,000).
|
|
func (s *RedisStore) getAllWithPagination(ctx context.Context, namespace string, redisQuery string, queries []Query, selectFields []string) ([]SearchResult, *string, error) {
|
|
var allResults []SearchResult
|
|
offset := 0
|
|
|
|
for {
|
|
pageResults, err := s.executeSearch(ctx, namespace, redisQuery, queries, selectFields, offset, RedisMaxSearchResults)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if len(pageResults) == 0 {
|
|
break
|
|
}
|
|
|
|
allResults = append(allResults, pageResults...)
|
|
|
|
if len(pageResults) < RedisMaxSearchResults {
|
|
break
|
|
}
|
|
offset += len(pageResults)
|
|
}
|
|
|
|
return allResults, nil, nil
|
|
}
|
|
|
|
// executeSearch performs a single FT.SEARCH query with the given offset and limit.
|
|
func (s *RedisStore) executeSearch(ctx context.Context, namespace string, redisQuery string, queries []Query, selectFields []string, offset int, searchLimit int) ([]SearchResult, error) {
|
|
args := []interface{}{
|
|
"FT.SEARCH", namespace,
|
|
redisQuery,
|
|
}
|
|
|
|
if len(selectFields) > 0 {
|
|
args = append(args, "RETURN", len(selectFields))
|
|
for _, field := range selectFields {
|
|
args = append(args, field)
|
|
}
|
|
}
|
|
|
|
args = append(args, "LIMIT", offset, searchLimit, "DIALECT", "2")
|
|
|
|
result := s.client.Do(ctx, args...)
|
|
if result.Err() != nil {
|
|
errMsg := strings.ToLower(result.Err().Error())
|
|
if isQuerySyntaxError(errMsg) {
|
|
s.logger.Debug(fmt.Sprintf("FT.SEARCH DIALECT fallback triggered for namespace %s: %s", namespace, result.Err()))
|
|
compatArgs := make([]interface{}, 0, len(args)-2)
|
|
for i := 0; i < len(args); i++ {
|
|
if i+1 < len(args) && args[i] == "DIALECT" {
|
|
i++
|
|
continue
|
|
}
|
|
compatArgs = append(compatArgs, args[i])
|
|
}
|
|
result = s.client.Do(ctx, compatArgs...)
|
|
}
|
|
if result.Err() != nil {
|
|
errMsg = strings.ToLower(result.Err().Error())
|
|
if isQuerySyntaxError(errMsg) {
|
|
if IsScanFallbackDisabled(ctx) {
|
|
return nil, fmt.Errorf("%w: %w", ErrQuerySyntax, result.Err())
|
|
}
|
|
s.logger.Debug(fmt.Sprintf("FT.SEARCH scan fallback triggered for namespace %s: %s", namespace, result.Err()))
|
|
scanResults, _, scanErr := s.getAllByScan(ctx, namespace, queries, selectFields, nil, int64(searchLimit))
|
|
if scanErr != nil {
|
|
return nil, scanErr
|
|
}
|
|
return scanResults, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to search: %w", result.Err())
|
|
}
|
|
}
|
|
|
|
results, err := s.parseSearchResults(result.Val(), namespace, selectFields)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse search results: %w", err)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (s *RedisStore) getAllByScan(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
|
|
// Parse offset for deterministic in-memory pagination after full scan.
|
|
offset, err := parseOffsetCursor(cursor)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
all, err := s.scanAllMatchingResults(ctx, namespace, queries, selectFields)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Ensure stable pagination boundaries for offset cursors across calls.
|
|
sort.Slice(all, func(i, j int) bool {
|
|
return all[i].ID < all[j].ID
|
|
})
|
|
|
|
if offset > len(all) {
|
|
offset = len(all)
|
|
}
|
|
|
|
if limit == 0 {
|
|
return all[offset:], nil, nil
|
|
}
|
|
if limit < 0 {
|
|
limit = BatchLimit
|
|
}
|
|
|
|
end := offset + int(limit)
|
|
if end > len(all) {
|
|
end = len(all)
|
|
}
|
|
|
|
results := all[offset:end]
|
|
var next *string
|
|
if end < len(all) {
|
|
nextCursorStr := strconv.Itoa(end)
|
|
next = &nextCursorStr
|
|
}
|
|
|
|
return results, next, nil
|
|
}
|
|
|
|
func (s *RedisStore) scanAllMatchingResults(ctx context.Context, namespace string, queries []Query, selectFields []string) ([]SearchResult, error) {
|
|
if clusterClient, ok := s.client.(*redis.ClusterClient); ok {
|
|
return s.scanAllMatchingResultsCluster(ctx, clusterClient, namespace, queries, selectFields)
|
|
}
|
|
return s.scanAllMatchingResultsSingle(ctx, s.client, namespace, queries, selectFields)
|
|
}
|
|
|
|
func (s *RedisStore) scanAllMatchingResultsSingle(ctx context.Context, client redis.Cmdable, namespace string, queries []Query, selectFields []string) ([]SearchResult, error) {
|
|
pattern := buildKey(namespace, "*")
|
|
var (
|
|
scanCursor uint64
|
|
all []SearchResult
|
|
)
|
|
|
|
for {
|
|
keys, nextCursor, err := client.Scan(ctx, scanCursor, pattern, BatchLimit).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan keys: %w", err)
|
|
}
|
|
|
|
matches, err := s.fetchMatchingSearchResults(ctx, client, namespace, keys, queries, selectFields)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
all = append(all, matches...)
|
|
|
|
scanCursor = nextCursor
|
|
if scanCursor == 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
return all, nil
|
|
}
|
|
|
|
func (s *RedisStore) scanAllMatchingResultsCluster(ctx context.Context, client *redis.ClusterClient, namespace string, queries []Query, selectFields []string) ([]SearchResult, error) {
|
|
var (
|
|
all []SearchResult
|
|
allMu sync.Mutex
|
|
seenIDs = make(map[string]struct{})
|
|
seenIDsMu sync.Mutex
|
|
)
|
|
|
|
err := client.ForEachMaster(ctx, func(ctx context.Context, nodeClient *redis.Client) error {
|
|
matches, err := s.scanAllMatchingResultsSingle(ctx, nodeClient, namespace, queries, selectFields)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
unique := make([]SearchResult, 0, len(matches))
|
|
seenIDsMu.Lock()
|
|
for _, match := range matches {
|
|
if _, ok := seenIDs[match.ID]; ok {
|
|
continue
|
|
}
|
|
seenIDs[match.ID] = struct{}{}
|
|
unique = append(unique, match)
|
|
}
|
|
seenIDsMu.Unlock()
|
|
|
|
if len(unique) == 0 {
|
|
return nil
|
|
}
|
|
|
|
allMu.Lock()
|
|
all = append(all, unique...)
|
|
allMu.Unlock()
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan cluster nodes: %w", err)
|
|
}
|
|
|
|
return all, nil
|
|
}
|
|
|
|
func (s *RedisStore) fetchMatchingSearchResults(ctx context.Context, client redis.Cmdable, namespace string, keys []string, queries []Query, selectFields []string) ([]SearchResult, error) {
|
|
if len(keys) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
pipe := client.Pipeline()
|
|
cmds := make([]*redis.MapStringStringCmd, len(keys))
|
|
for i, key := range keys {
|
|
cmds[i] = pipe.HGetAll(ctx, key)
|
|
}
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return nil, fmt.Errorf("failed to fetch scanned keys: %w", err)
|
|
}
|
|
|
|
results := make([]SearchResult, 0, len(keys))
|
|
for i, cmd := range cmds {
|
|
if cmd.Err() != nil {
|
|
continue
|
|
}
|
|
fields := cmd.Val()
|
|
if len(fields) == 0 {
|
|
continue
|
|
}
|
|
|
|
key := keys[i]
|
|
id := strings.TrimPrefix(key, namespace+":")
|
|
if id == key {
|
|
continue
|
|
}
|
|
|
|
properties := make(map[string]interface{}, len(fields))
|
|
for k, v := range fields {
|
|
properties[k] = v
|
|
}
|
|
|
|
if !matchesQueriesForScan(properties, queries) {
|
|
continue
|
|
}
|
|
|
|
searchResult := SearchResult{
|
|
ID: id,
|
|
Properties: make(map[string]interface{}),
|
|
}
|
|
|
|
if len(selectFields) == 0 {
|
|
searchResult.Properties = properties
|
|
} else {
|
|
for _, field := range selectFields {
|
|
if val, ok := properties[field]; ok {
|
|
searchResult.Properties[field] = val
|
|
}
|
|
}
|
|
}
|
|
|
|
results = append(results, searchResult)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func matchesQueriesForScan(properties map[string]interface{}, queries []Query) bool {
|
|
for _, q := range queries {
|
|
raw, exists := properties[q.Field]
|
|
|
|
// NOTE: missing fields are treated as non-matching for most operators
|
|
// (Equal, Like, GreaterThan, etc.) but pass NotEqual — i.e. a document
|
|
// without the field is considered "not equal" to any value. This differs
|
|
// from SQL NULL semantics where NULL != value evaluates to NULL/unknown.
|
|
// Change this if scan results need to match FT.SEARCH behavior exactly.
|
|
|
|
rawStr := fmt.Sprintf("%v", raw)
|
|
queryStr := fmt.Sprintf("%v", q.Value)
|
|
|
|
switch q.Operator {
|
|
case QueryOperatorEqual:
|
|
if !exists || rawStr != queryStr {
|
|
return false
|
|
}
|
|
case QueryOperatorNotEqual:
|
|
if exists && rawStr == queryStr {
|
|
return false
|
|
}
|
|
case QueryOperatorIsNull:
|
|
if exists {
|
|
return false
|
|
}
|
|
case QueryOperatorIsNotNull:
|
|
if !exists {
|
|
return false
|
|
}
|
|
case QueryOperatorLike:
|
|
if !exists || !strings.Contains(strings.ToLower(rawStr), strings.ToLower(queryStr)) {
|
|
return false
|
|
}
|
|
case QueryOperatorGreaterThan:
|
|
if !exists {
|
|
return false
|
|
}
|
|
rawF, errR := strconv.ParseFloat(rawStr, 64)
|
|
queryF, errQ := strconv.ParseFloat(queryStr, 64)
|
|
if errR != nil || errQ != nil || rawF <= queryF {
|
|
return false
|
|
}
|
|
case QueryOperatorGreaterThanOrEqual:
|
|
if !exists {
|
|
return false
|
|
}
|
|
rawF, errR := strconv.ParseFloat(rawStr, 64)
|
|
queryF, errQ := strconv.ParseFloat(queryStr, 64)
|
|
if errR != nil || errQ != nil || rawF < queryF {
|
|
return false
|
|
}
|
|
case QueryOperatorLessThan:
|
|
if !exists {
|
|
return false
|
|
}
|
|
rawF, errR := strconv.ParseFloat(rawStr, 64)
|
|
queryF, errQ := strconv.ParseFloat(queryStr, 64)
|
|
if errR != nil || errQ != nil || rawF >= queryF {
|
|
return false
|
|
}
|
|
case QueryOperatorLessThanOrEqual:
|
|
if !exists {
|
|
return false
|
|
}
|
|
rawF, errR := strconv.ParseFloat(rawStr, 64)
|
|
queryF, errQ := strconv.ParseFloat(queryStr, 64)
|
|
if errR != nil || errQ != nil || rawF > queryF {
|
|
return false
|
|
}
|
|
case QueryOperatorContainsAny:
|
|
if !exists {
|
|
return false
|
|
}
|
|
propertyValues, ok := parseStringValuesForContains(raw)
|
|
if !ok {
|
|
return false
|
|
}
|
|
queryValues, ok := parseQueryContainsValues(q.Value)
|
|
if !ok {
|
|
return false
|
|
}
|
|
if !containsAnyString(propertyValues, queryValues) {
|
|
return false
|
|
}
|
|
case QueryOperatorContainsAll:
|
|
if !exists {
|
|
return false
|
|
}
|
|
propertyValues, ok := parseStringValuesForContains(raw)
|
|
if !ok {
|
|
return false
|
|
}
|
|
queryValues, ok := parseQueryContainsValues(q.Value)
|
|
if !ok {
|
|
return false
|
|
}
|
|
if !containsAllStrings(propertyValues, queryValues) {
|
|
return false
|
|
}
|
|
default:
|
|
// Conservative fallback: require exact match semantics for unsupported operators.
|
|
if !exists || rawStr != queryStr {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// parseSearchResults parses FT.SEARCH results into SearchResult slice.
|
|
func (s *RedisStore) parseSearchResults(result interface{}, namespace string, selectFields []string) ([]SearchResult, error) {
|
|
results := []SearchResult{}
|
|
|
|
// RESP3 style in Redis/Valkey:
|
|
// map{ "results": [ { "id": "...", "extra_attributes": {...} } ] }
|
|
switch typed := result.(type) {
|
|
case map[interface{}]interface{}:
|
|
rawResults, ok := typed["results"]
|
|
if !ok {
|
|
return results, nil
|
|
}
|
|
resultItems, ok := rawResults.([]interface{})
|
|
if !ok {
|
|
return results, nil
|
|
}
|
|
for _, item := range resultItems {
|
|
if parsed, ok := parseSearchResultDocument(item, namespace, selectFields); ok {
|
|
results = append(results, parsed)
|
|
}
|
|
}
|
|
return results, nil
|
|
case map[string]interface{}:
|
|
rawResults, ok := typed["results"]
|
|
if !ok {
|
|
return results, nil
|
|
}
|
|
resultItems, ok := rawResults.([]interface{})
|
|
if !ok {
|
|
return results, nil
|
|
}
|
|
for _, item := range resultItems {
|
|
if parsed, ok := parseSearchResultDocument(item, namespace, selectFields); ok {
|
|
results = append(results, parsed)
|
|
}
|
|
}
|
|
return results, nil
|
|
case []interface{}:
|
|
// RESP2 style in Redis/Valkey:
|
|
// [total, "namespace:id", ["field", "value", ...], ...]
|
|
if len(typed) < 3 {
|
|
return results, nil
|
|
}
|
|
for i := 1; i+1 < len(typed); i += 2 {
|
|
idValue := typed[i]
|
|
attrsValue := typed[i+1]
|
|
doc := map[string]interface{}{
|
|
"id": idValue,
|
|
"extra_attributes": attrsValue,
|
|
}
|
|
if parsed, ok := parseSearchResultDocument(doc, namespace, selectFields); ok {
|
|
results = append(results, parsed)
|
|
}
|
|
}
|
|
return results, nil
|
|
default:
|
|
return results, nil
|
|
}
|
|
}
|
|
|
|
func parseSearchResultIDs(result interface{}, namespace string) []string {
|
|
ids := make([]string, 0)
|
|
appendID := func(value interface{}) {
|
|
id, ok := toString(value)
|
|
if !ok {
|
|
return
|
|
}
|
|
id = strings.TrimSpace(id)
|
|
if id == "" {
|
|
return
|
|
}
|
|
if namespace != "" {
|
|
prefix := namespace + ":"
|
|
if strings.HasPrefix(id, prefix) {
|
|
id = strings.TrimPrefix(id, prefix)
|
|
}
|
|
}
|
|
if id == "" {
|
|
return
|
|
}
|
|
ids = append(ids, id)
|
|
}
|
|
|
|
extractRESP3IDs := func(rawResults interface{}) {
|
|
resultItems, ok := rawResults.([]interface{})
|
|
if !ok {
|
|
return
|
|
}
|
|
for _, item := range resultItems {
|
|
switch doc := item.(type) {
|
|
case map[string]interface{}:
|
|
appendID(doc["id"])
|
|
case map[interface{}]interface{}:
|
|
appendID(doc["id"])
|
|
default:
|
|
appendID(item)
|
|
}
|
|
}
|
|
}
|
|
|
|
switch typed := result.(type) {
|
|
case map[interface{}]interface{}:
|
|
extractRESP3IDs(typed["results"])
|
|
case map[string]interface{}:
|
|
extractRESP3IDs(typed["results"])
|
|
case []interface{}:
|
|
if len(typed) < 2 {
|
|
return ids
|
|
}
|
|
for i := 1; i < len(typed); i++ {
|
|
appendID(typed[i])
|
|
|
|
// RESP2 payloads can be [total, id, attrs, id, attrs, ...].
|
|
if i+1 < len(typed) {
|
|
switch typed[i+1].(type) {
|
|
case []interface{}, map[string]interface{}, map[interface{}]interface{}:
|
|
i++
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return ids
|
|
}
|
|
|
|
func parseSearchResultDocument(resultItem interface{}, namespace string, selectFields []string) (SearchResult, bool) {
|
|
var docMap map[string]interface{}
|
|
|
|
switch item := resultItem.(type) {
|
|
case map[string]interface{}:
|
|
docMap = item
|
|
case map[interface{}]interface{}:
|
|
docMap = make(map[string]interface{}, len(item))
|
|
for k, v := range item {
|
|
docMap[fmt.Sprintf("%v", k)] = v
|
|
}
|
|
default:
|
|
return SearchResult{}, false
|
|
}
|
|
|
|
idRaw, ok := docMap["id"]
|
|
if !ok {
|
|
return SearchResult{}, false
|
|
}
|
|
|
|
id, ok := toString(idRaw)
|
|
if !ok {
|
|
return SearchResult{}, false
|
|
}
|
|
|
|
docID := id
|
|
if namespace != "" {
|
|
prefix := namespace + ":"
|
|
if strings.HasPrefix(id, prefix) {
|
|
docID = strings.TrimPrefix(id, prefix)
|
|
}
|
|
}
|
|
|
|
attrsRaw, ok := docMap["extra_attributes"]
|
|
if !ok {
|
|
return SearchResult{}, false
|
|
}
|
|
|
|
attrs := attributesToMap(attrsRaw)
|
|
if attrs == nil {
|
|
return SearchResult{}, false
|
|
}
|
|
|
|
searchResult := SearchResult{
|
|
ID: docID,
|
|
Properties: make(map[string]interface{}, len(attrs)),
|
|
}
|
|
|
|
for fieldName, fieldValue := range attrs {
|
|
if fieldName == "score" {
|
|
searchResult.Properties[fieldName] = fieldValue
|
|
if scoreFloat, ok := toFloat64(fieldValue); ok {
|
|
searchResult.Score = &scoreFloat
|
|
}
|
|
continue
|
|
}
|
|
|
|
if len(selectFields) > 0 && !containsField(selectFields, fieldName) {
|
|
continue
|
|
}
|
|
|
|
searchResult.Properties[fieldName] = fieldValue
|
|
}
|
|
|
|
return searchResult, true
|
|
}
|
|
|
|
func attributesToMap(value interface{}) map[string]interface{} {
|
|
switch attrs := value.(type) {
|
|
case map[string]interface{}:
|
|
return attrs
|
|
case map[interface{}]interface{}:
|
|
out := make(map[string]interface{}, len(attrs))
|
|
for k, v := range attrs {
|
|
out[fmt.Sprintf("%v", k)] = v
|
|
}
|
|
return out
|
|
case []interface{}:
|
|
// RESP2 attribute pairs: ["field", "value", "field2", "value2", ...]
|
|
if len(attrs)%2 != 0 {
|
|
return nil
|
|
}
|
|
out := make(map[string]interface{}, len(attrs)/2)
|
|
for i := 0; i+1 < len(attrs); i += 2 {
|
|
key, ok := toString(attrs[i])
|
|
if !ok {
|
|
continue
|
|
}
|
|
out[key] = attrs[i+1]
|
|
}
|
|
return out
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// toString returns value as a string when it is a string or a []byte.
// For every other type it returns "" and false.
func toString(value interface{}) (string, bool) {
	if text, ok := value.(string); ok {
		return text, true
	}
	if raw, ok := value.([]byte); ok {
		return string(raw), true
	}
	return "", false
}
|
|
|
|
// toFloat64 converts value to a float64. float64, float32, int and int64 are
// converted directly; string and []byte are parsed with strconv.ParseFloat.
// It returns 0 and false for any other type or for text that does not parse.
func toFloat64(value interface{}) (float64, bool) {
	var text string

	switch v := value.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case int:
		return float64(v), true
	case int64:
		return float64(v), true
	case string:
		text = v
	case []byte:
		text = string(v)
	default:
		return 0, false
	}

	// Only the textual cases reach this point.
	number, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return 0, false
	}
	return number, true
}
|
|
|
|
// containsField reports whether candidate appears in fields (exact,
// case-sensitive comparison).
func containsField(fields []string, candidate string) bool {
	for i := 0; i < len(fields); i++ {
		if fields[i] == candidate {
			return true
		}
	}
	return false
}
|
|
|
|
func (s *RedisStore) cacheNamespaceFieldTypes(namespace string, properties map[string]VectorStoreProperties) {
|
|
if strings.TrimSpace(namespace) == "" || len(properties) == 0 {
|
|
return
|
|
}
|
|
|
|
fieldTypes := make(map[string]VectorStorePropertyType, len(properties))
|
|
for field, prop := range properties {
|
|
fieldTypes[field] = prop.DataType
|
|
}
|
|
|
|
s.namespaceFieldTypesMu.Lock()
|
|
defer s.namespaceFieldTypesMu.Unlock()
|
|
if s.namespaceFieldTypes == nil {
|
|
s.namespaceFieldTypes = make(map[string]map[string]VectorStorePropertyType)
|
|
}
|
|
s.namespaceFieldTypes[namespace] = fieldTypes
|
|
}
|
|
|
|
func (s *RedisStore) deleteNamespaceFieldTypes(namespace string) {
|
|
if strings.TrimSpace(namespace) == "" {
|
|
return
|
|
}
|
|
s.namespaceFieldTypesMu.Lock()
|
|
defer s.namespaceFieldTypesMu.Unlock()
|
|
delete(s.namespaceFieldTypes, namespace)
|
|
}
|
|
|
|
func (s *RedisStore) getNamespaceFieldTypes(namespace string) map[string]VectorStorePropertyType {
|
|
if strings.TrimSpace(namespace) == "" {
|
|
return nil
|
|
}
|
|
|
|
s.namespaceFieldTypesMu.RLock()
|
|
defer s.namespaceFieldTypesMu.RUnlock()
|
|
|
|
fieldTypes, ok := s.namespaceFieldTypes[namespace]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
copied := make(map[string]VectorStorePropertyType, len(fieldTypes))
|
|
for field, dataType := range fieldTypes {
|
|
copied[field] = dataType
|
|
}
|
|
return copied
|
|
}
|
|
|
|
// buildRedisQuery converts []Query to Redis query syntax
|
|
func buildRedisQuery(queries []Query, fieldTypes map[string]VectorStorePropertyType) string {
|
|
if len(queries) == 0 {
|
|
return "*"
|
|
}
|
|
|
|
var conditions []string
|
|
for _, query := range queries {
|
|
condition := buildRedisQueryCondition(query, fieldTypes)
|
|
if condition != "" {
|
|
conditions = append(conditions, condition)
|
|
}
|
|
}
|
|
|
|
if len(conditions) == 0 {
|
|
return "*"
|
|
}
|
|
|
|
// Join conditions with space (AND operation in Redis)
|
|
return strings.Join(conditions, " ")
|
|
}
|
|
|
|
func shouldUseNumericEquality(field string, value interface{}, fieldTypes map[string]VectorStorePropertyType) (string, bool) {
|
|
if fieldTypes != nil {
|
|
if dataType, ok := fieldTypes[field]; ok {
|
|
if dataType == VectorStorePropertyTypeInteger {
|
|
return normalizeNumericQueryValue(value)
|
|
}
|
|
return "", false
|
|
}
|
|
}
|
|
return normalizeNumericQueryValue(value)
|
|
}
|
|
|
|
// normalizeNumericQueryValue renders value as plain numeric text.
//
// Signed and unsigned integers are written in base 10. Floats are written in
// 'f' format with the shortest representation that round-trips at their own
// bit size. Strings are whitespace-trimmed and accepted only if
// strconv.ParseFloat can parse them; the trimmed text is returned unchanged.
// All other types, and blank or non-numeric strings, yield "" and false.
func normalizeNumericQueryValue(value interface{}) (string, bool) {
	switch v := value.(type) {
	case int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64:
		// %d prints every integer kind in base 10.
		return fmt.Sprintf("%d", v), true

	case float32:
		return strconv.FormatFloat(float64(v), 'f', -1, 32), true

	case float64:
		return strconv.FormatFloat(v, 'f', -1, 64), true

	case string:
		candidate := strings.TrimSpace(v)
		if candidate == "" {
			return "", false
		}
		if _, err := strconv.ParseFloat(candidate, 64); err != nil {
			return "", false
		}
		return candidate, true
	}

	return "", false
}
|
|
|
|
// buildRedisQueryCondition builds a single Redis query condition
|
|
func buildRedisQueryCondition(query Query, fieldTypes map[string]VectorStorePropertyType) string {
|
|
field := query.Field
|
|
operator := query.Operator
|
|
value := query.Value
|
|
|
|
// Convert value to string
|
|
var stringValue string
|
|
switch val := value.(type) {
|
|
case string:
|
|
stringValue = val
|
|
case int, int64, float64, bool:
|
|
stringValue = fmt.Sprintf("%v", val)
|
|
default:
|
|
jsonData, _ := json.Marshal(val)
|
|
stringValue = string(jsonData)
|
|
}
|
|
|
|
// Escape special characters for TAG fields
|
|
escapedValue := escapeSearchValue(stringValue) // new function for TAG escaping
|
|
|
|
switch operator {
|
|
case QueryOperatorEqual:
|
|
if numericValue, useNumeric := shouldUseNumericEquality(field, value, fieldTypes); useNumeric {
|
|
return fmt.Sprintf("@%s:[%s %s]", field, numericValue, numericValue)
|
|
}
|
|
// TAG exact match
|
|
return fmt.Sprintf("@%s:{%s}", field, escapedValue)
|
|
case QueryOperatorNotEqual:
|
|
if numericValue, useNumeric := shouldUseNumericEquality(field, value, fieldTypes); useNumeric {
|
|
return fmt.Sprintf("-@%s:[%s %s]", field, numericValue, numericValue)
|
|
}
|
|
// TAG negation
|
|
return fmt.Sprintf("-@%s:{%s}", field, escapedValue)
|
|
case QueryOperatorLike:
|
|
// Cannot do LIKE with TAGs directly; fallback to exact match
|
|
return fmt.Sprintf("@%s:{%s}", field, escapedValue)
|
|
case QueryOperatorGreaterThan:
|
|
return fmt.Sprintf("@%s:[(%s +inf]", field, escapedValue)
|
|
case QueryOperatorGreaterThanOrEqual:
|
|
return fmt.Sprintf("@%s:[%s +inf]", field, escapedValue)
|
|
case QueryOperatorLessThan:
|
|
return fmt.Sprintf("@%s:[-inf (%s]", field, escapedValue)
|
|
case QueryOperatorLessThanOrEqual:
|
|
return fmt.Sprintf("@%s:[-inf %s]", field, escapedValue)
|
|
case QueryOperatorIsNull:
|
|
// Field not present
|
|
return fmt.Sprintf("-@%s:*", field)
|
|
case QueryOperatorIsNotNull:
|
|
// Field exists
|
|
return fmt.Sprintf("@%s:*", field)
|
|
case QueryOperatorContainsAny:
|
|
if values, ok := value.([]interface{}); ok {
|
|
var orConditions []string
|
|
for _, v := range values {
|
|
vStr := fmt.Sprintf("%v", v)
|
|
orConditions = append(orConditions, fmt.Sprintf("@%s:{%s}", field, escapeSearchValue(vStr)))
|
|
}
|
|
return fmt.Sprintf("(%s)", strings.Join(orConditions, " | "))
|
|
}
|
|
return fmt.Sprintf("@%s:{%s}", field, escapedValue)
|
|
case QueryOperatorContainsAll:
|
|
if values, ok := value.([]interface{}); ok {
|
|
var andConditions []string
|
|
for _, v := range values {
|
|
vStr := fmt.Sprintf("%v", v)
|
|
andConditions = append(andConditions, fmt.Sprintf("@%s:{%s}", field, escapeSearchValue(vStr)))
|
|
}
|
|
return strings.Join(andConditions, " ")
|
|
}
|
|
return fmt.Sprintf("@%s:{%s}", field, escapedValue)
|
|
default:
|
|
return fmt.Sprintf("@%s:{%s}", field, escapedValue)
|
|
}
|
|
}
|
|
|
|
// GetNearest retrieves the nearest chunks from the Redis vector store.
|
|
func (s *RedisStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
// Build Redis query from the provided queries
|
|
redisQuery := buildRedisQuery(queries, s.getNamespaceFieldTypes(namespace))
|
|
|
|
// Convert query embedding to binary format
|
|
queryBytes := float32SliceToBytes(vector)
|
|
|
|
// Build hybrid FT.SEARCH query: metadata filters + KNN vector search
|
|
// The correct syntax is: (metadata_filter)=>[KNN k @embedding $vec AS score]
|
|
var hybridQuery string
|
|
if len(queries) > 0 {
|
|
// Wrap metadata query in parentheses for hybrid syntax
|
|
hybridQuery = fmt.Sprintf("(%s)", redisQuery)
|
|
} else {
|
|
// Wildcard for pure vector search
|
|
hybridQuery = "*"
|
|
}
|
|
|
|
// Execute FT.SEARCH with KNN
|
|
// Use large limit for "all" (limit=0) in KNN query
|
|
knnLimit := limit
|
|
if limit == 0 {
|
|
knnLimit = math.MaxInt32
|
|
}
|
|
|
|
args := []interface{}{
|
|
"FT.SEARCH", namespace,
|
|
fmt.Sprintf("%s=>[KNN %d @embedding $vec AS score]", hybridQuery, knnLimit),
|
|
"PARAMS", "2", "vec", queryBytes,
|
|
"SORTBY", "score",
|
|
}
|
|
|
|
// Add RETURN clause - always include score for vector search
|
|
// For vector search, we need to include the score field generated by KNN
|
|
returnFields := []string{"score"}
|
|
if len(selectFields) > 0 {
|
|
returnFields = append(returnFields, selectFields...)
|
|
}
|
|
|
|
args = append(args, "RETURN", len(returnFields))
|
|
for _, field := range returnFields {
|
|
args = append(args, field)
|
|
}
|
|
|
|
// Add LIMIT clause and DIALECT 2 for better query parsing
|
|
searchLimit := limit
|
|
if limit == 0 {
|
|
searchLimit = math.MaxInt32
|
|
}
|
|
args = append(args, "LIMIT", 0, int(searchLimit), "DIALECT", "2")
|
|
|
|
result := s.client.Do(ctx, args...)
|
|
if result.Err() != nil {
|
|
errMsg := strings.ToLower(result.Err().Error())
|
|
// Some Valkey implementations reject SORTBY in KNN search (already distance-ordered).
|
|
if strings.Contains(errMsg, "unexpected argument `sortby`") || strings.Contains(errMsg, "unexpected argument sortby") {
|
|
compatArgs := make([]interface{}, 0, len(args)-2)
|
|
for i := 0; i < len(args); i++ {
|
|
if i+1 < len(args) && args[i] == "SORTBY" {
|
|
i++ // skip sort field value too
|
|
continue
|
|
}
|
|
compatArgs = append(compatArgs, args[i])
|
|
}
|
|
result = s.client.Do(ctx, compatArgs...)
|
|
}
|
|
if result.Err() != nil {
|
|
return nil, fmt.Errorf("native vector search failed: %w", result.Err())
|
|
}
|
|
}
|
|
|
|
// Parse search results
|
|
results, err := s.parseSearchResults(result.Val(), namespace, selectFields)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Apply threshold filter and extract scores
|
|
var filteredResults []SearchResult
|
|
for _, result := range results {
|
|
// Extract score from the result
|
|
if scoreValue, exists := result.Properties["score"]; exists {
|
|
score, ok := toFloat64(scoreValue)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
// Convert cosine distance to similarity: similarity = 1 - distance
|
|
similarity := 1.0 - score
|
|
result.Score = &similarity
|
|
|
|
// Apply threshold filter
|
|
if similarity >= threshold {
|
|
filteredResults = append(filteredResults, result)
|
|
}
|
|
} else {
|
|
// If no score, include the result (shouldn't happen with KNN queries)
|
|
filteredResults = append(filteredResults, result)
|
|
}
|
|
}
|
|
|
|
results = filteredResults
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// Add stores a new chunk in the Redis vector store.
|
|
func (s *RedisStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
if strings.TrimSpace(id) == "" {
|
|
return fmt.Errorf("id is required")
|
|
}
|
|
|
|
// Create key with namespace
|
|
key := buildKey(namespace, id)
|
|
|
|
// Prepare hash fields: binary embedding + metadata
|
|
fields := make(map[string]interface{})
|
|
|
|
// Only add embedding if it's not empty
|
|
if len(embedding) > 0 {
|
|
// Convert float32 slice to bytes for Redis storage
|
|
embeddingBytes := float32SliceToBytes(embedding)
|
|
fields["embedding"] = embeddingBytes
|
|
}
|
|
|
|
// Add metadata fields directly (no prefix needed with proper indexing)
|
|
for k, v := range metadata {
|
|
switch val := v.(type) {
|
|
case string:
|
|
fields[k] = val
|
|
case int, int64, float64, bool:
|
|
fields[k] = fmt.Sprintf("%v", val)
|
|
case []interface{}:
|
|
// Preserve arrays as JSON to support round-trips (e.g., stream_chunks)
|
|
b, err := json.Marshal(val)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal array metadata %s: %w", k, err)
|
|
}
|
|
fields[k] = string(b)
|
|
default:
|
|
// JSON encode complex types
|
|
jsonData, err := json.Marshal(val)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal metadata field %s: %w", k, err)
|
|
}
|
|
fields[k] = string(jsonData)
|
|
}
|
|
}
|
|
|
|
// Store as hash for efficient native vector search
|
|
if err := s.client.HSet(ctx, key, fields).Err(); err != nil {
|
|
return fmt.Errorf("failed to store semantic cache entry: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete deletes a chunk from the Redis vector store.
|
|
func (s *RedisStore) Delete(ctx context.Context, namespace string, id string) error {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
if strings.TrimSpace(id) == "" {
|
|
return fmt.Errorf("id is required")
|
|
}
|
|
|
|
// Create key with namespace
|
|
key := buildKey(namespace, id)
|
|
|
|
// Delete the hash key
|
|
result := s.client.Del(ctx, key)
|
|
if result.Err() != nil {
|
|
return fmt.Errorf("failed to delete chunk %s: %w", id, result.Err())
|
|
}
|
|
|
|
// Check if the key actually existed
|
|
if result.Val() == 0 {
|
|
return fmt.Errorf("chunk not found: %s", id)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteAll deletes all chunks from the Redis vector store.
|
|
func (s *RedisStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
return s.deleteAllBySnapshot(ctx, namespace, queries)
|
|
}
|
|
|
|
// deleteAllBySnapshot snapshots matching ids before deleting to avoid
|
|
// offset/cursor drift while mutating the dataset.
|
|
func (s *RedisStore) deleteAllBySnapshot(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) {
|
|
ids, err := s.getAllMatchingIDs(ctx, namespace, queries)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to find documents to delete: %w", err)
|
|
}
|
|
|
|
if len(ids) == 0 {
|
|
return []DeleteResult{}, nil
|
|
}
|
|
|
|
// Delete this batch of documents
|
|
var deleteResults []DeleteResult
|
|
batchSize := BatchLimit // Process in batches to avoid overwhelming Redis
|
|
|
|
for i := 0; i < len(ids); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(ids) {
|
|
end = len(ids)
|
|
}
|
|
batch := ids[i:end]
|
|
|
|
// Create pipeline for batch deletion
|
|
pipe := s.client.Pipeline()
|
|
cmds := make([]*redis.IntCmd, len(batch))
|
|
|
|
for j, id := range batch {
|
|
key := buildKey(namespace, id)
|
|
cmds[j] = pipe.Del(ctx, key)
|
|
}
|
|
|
|
// Execute pipeline
|
|
_, err := pipe.Exec(ctx)
|
|
if err != nil {
|
|
// If pipeline fails, mark all in this batch as failed
|
|
for _, id := range batch {
|
|
deleteResults = append(deleteResults, DeleteResult{
|
|
ID: id,
|
|
Status: DeleteStatusError,
|
|
Error: fmt.Sprintf("pipeline execution failed: %v", err),
|
|
})
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Process results for this batch
|
|
for j, cmd := range cmds {
|
|
id := batch[j]
|
|
if cmd.Err() != nil {
|
|
deleteResults = append(deleteResults, DeleteResult{
|
|
ID: id,
|
|
Status: DeleteStatusError,
|
|
Error: cmd.Err().Error(),
|
|
})
|
|
} else if cmd.Val() > 0 {
|
|
// Key existed and was deleted
|
|
deleteResults = append(deleteResults, DeleteResult{
|
|
ID: id,
|
|
Status: DeleteStatusSuccess,
|
|
})
|
|
} else {
|
|
// Key didn't exist
|
|
deleteResults = append(deleteResults, DeleteResult{
|
|
ID: id,
|
|
Status: DeleteStatusError,
|
|
Error: "document not found",
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
return deleteResults, nil
|
|
}
|
|
|
|
func (s *RedisStore) getAllMatchingIDs(ctx context.Context, namespace string, queries []Query) ([]string, error) {
|
|
redisQuery := buildRedisQuery(queries, s.getNamespaceFieldTypes(namespace))
|
|
offset := 0
|
|
ids := make([]string, 0)
|
|
|
|
for {
|
|
args := []interface{}{
|
|
"FT.SEARCH", namespace,
|
|
redisQuery,
|
|
"RETURN", 0,
|
|
"LIMIT", offset, BatchLimit,
|
|
"DIALECT", "2",
|
|
}
|
|
|
|
result := s.client.Do(ctx, args...)
|
|
if result.Err() != nil {
|
|
errMsg := strings.ToLower(result.Err().Error())
|
|
if isQuerySyntaxError(errMsg) {
|
|
s.logger.Debug(fmt.Sprintf("FT.SEARCH DIALECT fallback triggered for namespace %s while collecting ids: %s", namespace, result.Err()))
|
|
compatArgs := make([]interface{}, 0, len(args)-2)
|
|
for i := 0; i < len(args); i++ {
|
|
if i+1 < len(args) && args[i] == "DIALECT" {
|
|
i++
|
|
continue
|
|
}
|
|
compatArgs = append(compatArgs, args[i])
|
|
}
|
|
result = s.client.Do(ctx, compatArgs...)
|
|
}
|
|
if result.Err() != nil {
|
|
errMsg = strings.ToLower(result.Err().Error())
|
|
if isQuerySyntaxError(errMsg) {
|
|
if IsScanFallbackDisabled(ctx) {
|
|
return nil, fmt.Errorf("failed to collect matching ids without scan fallback: %w", result.Err())
|
|
}
|
|
s.logger.Debug(fmt.Sprintf("FT.SEARCH scan fallback triggered for namespace %s while collecting ids: %s", namespace, result.Err()))
|
|
scanResults, _, scanErr := s.getAllByScan(ctx, namespace, queries, nil, nil, 0)
|
|
if scanErr != nil {
|
|
return nil, fmt.Errorf("failed to collect matching ids via scan fallback: %w", scanErr)
|
|
}
|
|
scanIDs := make([]string, 0, len(scanResults))
|
|
for _, scanResult := range scanResults {
|
|
scanIDs = append(scanIDs, scanResult.ID)
|
|
}
|
|
return scanIDs, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to search for matching ids: %w", result.Err())
|
|
}
|
|
}
|
|
|
|
pageIDs := parseSearchResultIDs(result.Val(), namespace)
|
|
if len(pageIDs) == 0 {
|
|
break
|
|
}
|
|
ids = append(ids, pageIDs...)
|
|
|
|
if len(pageIDs) < BatchLimit {
|
|
break
|
|
}
|
|
offset += len(pageIDs)
|
|
}
|
|
|
|
return ids, nil
|
|
}
|
|
|
|
// DeleteNamespace deletes a namespace from the Redis vector store.
|
|
func (s *RedisStore) DeleteNamespace(ctx context.Context, namespace string) error {
|
|
ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
|
|
defer cancel()
|
|
|
|
// Drop the index using FT.DROPINDEX
|
|
if err := s.client.Do(ctx, "FT.DROPINDEX", namespace).Err(); err != nil {
|
|
// Check if error is "Unknown Index name" - that's OK, index doesn't exist
|
|
if strings.Contains(strings.ToLower(err.Error()), "unknown index name") {
|
|
s.deleteNamespaceFieldTypes(namespace)
|
|
return nil // Index doesn't exist, nothing to drop
|
|
}
|
|
return fmt.Errorf("failed to drop semantic index %s: %w", namespace, err)
|
|
}
|
|
|
|
s.deleteNamespaceFieldTypes(namespace)
|
|
return nil
|
|
}
|
|
|
|
// Close closes the Redis vector store.
|
|
func (s *RedisStore) Close(ctx context.Context, namespace string) error {
|
|
// Close the Redis client connection
|
|
return s.client.Close()
|
|
}
|
|
|
|
// RequiresVectors returns false because Redis can store hash data with or without vectors.
|
|
func (s *RedisStore) RequiresVectors() bool {
|
|
return false
|
|
}
|
|
|
|
// redisSearchValueEscaper prefixes every character with special meaning in a
// RediSearch query with a backslash. The backslash itself is included so a
// value such as `a\` cannot escape the delimiter that follows it in the query.
// A strings.Replacer is safe for concurrent use, so it is built once.
var redisSearchValueEscaper = strings.NewReplacer(
	"\\", "\\\\",
	"(", "\\(",
	")", "\\)",
	"[", "\\[",
	"]", "\\]",
	"{", "\\{",
	"}", "\\}",
	"*", "\\*",
	"?", "\\?",
	"|", "\\|",
	"&", "\\&",
	"!", "\\!",
	"@", "\\@",
	"#", "\\#",
	"$", "\\$",
	"%", "\\%",
	"^", "\\^",
	"~", "\\~",
	"`", "\\`",
	"\"", "\\\"",
	"'", "\\'",
	" ", "\\ ",
	"-", "\\-",
	".", "\\.",
	",", "\\,",
)

// escapeSearchValue returns value with every special RediSearch character
// (including the backslash) escaped by a leading backslash, so the result can
// be embedded in a query clause as a literal. All other characters are
// returned unchanged.
func escapeSearchValue(value string) string {
	return redisSearchValueEscaper.Replace(value)
}
|
|
|
|
// float32SliceToBytes encodes floats as consecutive 4-byte little-endian
// IEEE 754 values, in order. The result always has length 4*len(floats).
func float32SliceToBytes(floats []float32) []byte {
	const bytesPerFloat = 4

	encoded := make([]byte, bytesPerFloat*len(floats))
	offset := 0
	for _, value := range floats {
		binary.LittleEndian.PutUint32(encoded[offset:offset+bytesPerFloat], math.Float32bits(value))
		offset += bytesPerFloat
	}
	return encoded
}
|
|
|
|
// isQuerySyntaxError reports whether errMsg, which must already be lowercased,
// contains one of the phrases that indicate an incompatible search query
// syntax. The phrases cover error strings from Redis Stack, Valkey Search, and
// other compatible engines.
func isQuerySyntaxError(errMsg string) bool {
	markers := [...]string{
		"missing `=>`",
		"invalid filter",
		"invalid query",
		"vector query clause is missing",
	}
	for _, marker := range markers {
		if strings.Contains(errMsg, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
// parseOffsetCursor converts a pagination cursor into a numeric offset.
//
// A nil, empty, or malformed cursor yields offset 0 with no error. A parsed
// value above math.MaxInt32 or below zero yields an error.
func parseOffsetCursor(cursor *string) (int, error) {
	if cursor == nil || *cursor == "" {
		return 0, nil
	}

	parsed, err := strconv.ParseInt(*cursor, 10, 64)
	switch {
	case err != nil:
		// Keep existing behavior: malformed cursor is treated as offset 0.
		return 0, nil
	case parsed > math.MaxInt32:
		return 0, fmt.Errorf("offset value %d exceeds maximum allowed value", parsed)
	case parsed < 0:
		return 0, fmt.Errorf("offset value %d cannot be negative", parsed)
	}
	return int(parsed), nil
}
|
|
|
|
// parseStringValuesForContains normalises a stored field value into a list of
// strings for contains-style comparisons. The boolean result is always true.
//
//   - []string is returned unchanged.
//   - []interface{} has each element formatted with %v.
//   - A string that is blank yields an empty list; a string that starts with
//     "[" and decodes as a JSON array yields the formatted array elements
//     (Redis scan fallback values may be JSON-encoded arrays); any other
//     string yields a one-element list holding the original, untrimmed string.
//   - Any other value yields a one-element list holding its %v form.
func parseStringValuesForContains(value interface{}) ([]string, bool) {
	stringify := func(items []interface{}) []string {
		formatted := make([]string, 0, len(items))
		for _, item := range items {
			formatted = append(formatted, fmt.Sprintf("%v", item))
		}
		return formatted
	}

	switch typed := value.(type) {
	case []string:
		return typed, true
	case []interface{}:
		return stringify(typed), true
	case string:
		trimmed := strings.TrimSpace(typed)
		if trimmed == "" {
			return []string{}, true
		}
		if strings.HasPrefix(trimmed, "[") {
			var decoded []interface{}
			if json.Unmarshal([]byte(trimmed), &decoded) == nil {
				return stringify(decoded), true
			}
		}
		return []string{typed}, true
	}

	return []string{fmt.Sprintf("%v", value)}, true
}
|
|
|
|
// parseQueryContainsValues extracts the value list of a contains-style query.
// A []string is returned unchanged and a []interface{} has each element
// formatted with %v; both report true. Any other value yields (nil, false).
func parseQueryContainsValues(value interface{}) ([]string, bool) {
	if strs, isStrings := value.([]string); isStrings {
		return strs, true
	}

	items, isList := value.([]interface{})
	if !isList {
		return nil, false
	}

	formatted := make([]string, len(items))
	for i, item := range items {
		formatted[i] = fmt.Sprintf("%v", item)
	}
	return formatted, true
}
|
|
|
|
// containsAnyString reports whether at least one of needles occurs in
// haystack. An empty needles list never matches.
func containsAnyString(haystack []string, needles []string) bool {
	if len(needles) == 0 {
		return false
	}

	wanted := make(map[string]struct{}, len(needles))
	for _, needle := range needles {
		wanted[needle] = struct{}{}
	}

	for _, candidate := range haystack {
		if _, hit := wanted[candidate]; hit {
			return true
		}
	}
	return false
}
|
|
|
|
// containsAllStrings reports whether every one of needles occurs in haystack.
// An empty needles list never matches.
func containsAllStrings(haystack []string, needles []string) bool {
	if len(needles) == 0 {
		return false
	}

	present := make(map[string]bool, len(haystack))
	for _, item := range haystack {
		present[item] = true
	}

	for _, needle := range needles {
		if !present[needle] {
			return false
		}
	}
	return true
}
|
|
|
|
// buildKey returns the Redis key for id within namespace, in the form
// "<namespace>:<id>".
func buildKey(namespace, id string) string {
	// All operands are strings, so Sprint concatenates them without spaces.
	return fmt.Sprint(namespace, ":", id)
}
|
|
|
|
// newRedisStore creates a new Redis vector store.
|
|
func newRedisStore(_ context.Context, config RedisConfig, logger schemas.Logger) (*RedisStore, error) {
|
|
// Validate required fields
|
|
if config.Addr == nil || config.Addr.GetValue() == "" {
|
|
return nil, fmt.Errorf("redis addr is required")
|
|
}
|
|
if config.Username == nil {
|
|
config.Username = schemas.NewEnvVar("")
|
|
}
|
|
if config.Password == nil {
|
|
config.Password = schemas.NewEnvVar("")
|
|
}
|
|
db := 0
|
|
if config.DB != nil {
|
|
db = config.DB.CoerceInt(0)
|
|
}
|
|
|
|
// TLS configuration
|
|
var tlsConfig *tls.Config
|
|
if config.UseTLS.CoerceBool(false) {
|
|
tlsConfig = &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
InsecureSkipVerify: config.InsecureSkipVerify.CoerceBool(false),
|
|
}
|
|
if config.CACertPEM != nil && config.CACertPEM.GetValue() != "" {
|
|
rootCAs, err := systemCertPoolWithCA(config.CACertPEM.GetValue())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to configure Redis TLS CA certificate: %w", err)
|
|
}
|
|
tlsConfig.RootCAs = rootCAs
|
|
}
|
|
}
|
|
|
|
clusterMode := config.ClusterMode.CoerceBool(false)
|
|
|
|
var client redis.UniversalClient
|
|
if clusterMode {
|
|
// Redis Cluster does not support database selection
|
|
if db != 0 {
|
|
return nil, fmt.Errorf("redis cluster mode does not support database selection (DB must be 0)")
|
|
}
|
|
client = redis.NewClusterClient(&redis.ClusterOptions{
|
|
Addrs: []string{config.Addr.GetValue()},
|
|
Username: config.Username.GetValue(),
|
|
Password: config.Password.GetValue(),
|
|
Protocol: 3, // Explicitly use RESP3 protocol
|
|
TLSConfig: tlsConfig,
|
|
PoolSize: config.PoolSize,
|
|
MaxActiveConns: config.MaxActiveConns,
|
|
MinIdleConns: config.MinIdleConns,
|
|
MaxIdleConns: config.MaxIdleConns,
|
|
ConnMaxLifetime: config.ConnMaxLifetime,
|
|
ConnMaxIdleTime: config.ConnMaxIdleTime,
|
|
DialTimeout: config.DialTimeout,
|
|
ReadTimeout: config.ReadTimeout,
|
|
WriteTimeout: config.WriteTimeout,
|
|
})
|
|
} else {
|
|
client = redis.NewClient(&redis.Options{
|
|
Addr: config.Addr.GetValue(),
|
|
Username: config.Username.GetValue(),
|
|
Password: config.Password.GetValue(),
|
|
DB: db,
|
|
Protocol: 3, // Explicitly use RESP3 protocol
|
|
TLSConfig: tlsConfig,
|
|
PoolSize: config.PoolSize,
|
|
MaxActiveConns: config.MaxActiveConns,
|
|
MinIdleConns: config.MinIdleConns,
|
|
MaxIdleConns: config.MaxIdleConns,
|
|
ConnMaxLifetime: config.ConnMaxLifetime,
|
|
ConnMaxIdleTime: config.ConnMaxIdleTime,
|
|
DialTimeout: config.DialTimeout,
|
|
ReadTimeout: config.ReadTimeout,
|
|
WriteTimeout: config.WriteTimeout,
|
|
})
|
|
}
|
|
|
|
// Creating store connection
|
|
store := &RedisStore{
|
|
client: client,
|
|
config: config,
|
|
logger: logger,
|
|
namespaceFieldTypes: make(map[string]map[string]VectorStorePropertyType),
|
|
}
|
|
// Eagerly verify connectivity, consistent with other store constructors (e.g. Qdrant)
|
|
if err := store.Ping(context.Background()); err != nil {
|
|
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
// systemCertPoolWithCA returns a certificate pool holding the system roots
// plus the certificates in caCertPEM. If the system pool cannot be loaded, an
// empty pool is used as the base instead. It returns an error when no
// certificate could be parsed from caCertPEM.
func systemCertPoolWithCA(caCertPEM string) (*x509.CertPool, error) {
	pool, sysErr := x509.SystemCertPool()
	if sysErr != nil {
		pool = x509.NewCertPool()
	}

	if appended := pool.AppendCertsFromPEM([]byte(caCertPEM)); appended {
		return pool, nil
	}
	return nil, fmt.Errorf("failed to parse CA certificate PEM")
}
|