package vectorstore

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/maximhq/bifrost/core/schemas"
	"github.com/redis/go-redis/v9"
)

const (
	// BatchLimit is the default limit used for pagination and batch operations.
	BatchLimit = 100

	// RedisMaxSearchResults is the maximum number of results Redis Search returns in a single query.
	// This is the default MAXSEARCHRESULTS configuration in Redis Search.
	RedisMaxSearchResults = 10000
)

// RedisConfig holds the connection, TLS, cluster, pool and timeout settings
// for the Redis vector store.
type RedisConfig struct {
	// Connection settings
	Addr     *schemas.EnvVar `json:"addr"`               // Redis server address (host:port) - REQUIRED
	Username *schemas.EnvVar `json:"username,omitempty"` // Username for Redis AUTH (optional)
	Password *schemas.EnvVar `json:"password,omitempty"` // Password for Redis AUTH (optional)
	DB       *schemas.EnvVar `json:"db,omitempty"`       // Redis database number (default: 0)

	// TLS settings
	UseTLS             *schemas.EnvVar `json:"use_tls,omitempty"`              // Enable TLS for connection (default: false)
	InsecureSkipVerify *schemas.EnvVar `json:"insecure_skip_verify,omitempty"` // Skip TLS cert verification (default: false)
	CACertPEM          *schemas.EnvVar `json:"ca_cert_pem,omitempty"`          // PEM-encoded CA certificate to trust for Redis/Valkey TLS

	// Cluster mode
	ClusterMode *schemas.EnvVar `json:"cluster_mode,omitempty"` // Use Redis Cluster client (default: false)

	// Connection pool and timeout settings (passed directly to Redis client)
	PoolSize        int           `json:"pool_size,omitempty"`          // Maximum number of socket connections (optional)
	MaxActiveConns  int           `json:"max_active_conns,omitempty"`   // Maximum number of active connections (optional)
	MinIdleConns    int           `json:"min_idle_conns,omitempty"`     // Minimum number of idle connections (optional)
	MaxIdleConns    int           `json:"max_idle_conns,omitempty"`     // Maximum number of idle connections (optional)
	ConnMaxLifetime time.Duration `json:"conn_max_lifetime,omitempty"`  // Connection maximum lifetime (optional)
	ConnMaxIdleTime time.Duration `json:"conn_max_idle_time,omitempty"` // Connection maximum idle time (optional)
	DialTimeout     time.Duration `json:"dial_timeout,omitempty"`       // Timeout for socket connection (optional)
	ReadTimeout     time.Duration `json:"read_timeout,omitempty"`       // Timeout for socket reads (optional)
	WriteTimeout    time.Duration `json:"write_timeout,omitempty"`      // Timeout for socket writes (optional)
	ContextTimeout  time.Duration `json:"context_timeout,omitempty"`    // Timeout for Redis operations (optional)
}

// RedisStore represents the Redis vector store.
type RedisStore struct {
	client redis.UniversalClient
	config RedisConfig
	logger schemas.Logger

	// namespaceFieldTypes caches the declared property types per namespace so
	// queries can pick NUMERIC vs TAG syntax; guarded by namespaceFieldTypesMu.
	namespaceFieldTypesMu sync.RWMutex
	namespaceFieldTypes   map[string]map[string]VectorStorePropertyType
}

// Ping checks if the Redis server is reachable.
func (s *RedisStore) Ping(ctx context.Context) error {
	return s.client.Ping(ctx).Err()
}

// CreateNamespace creates a new namespace in the Redis vector store.
//
// A namespace is a search index named after the namespace, built over HASH
// keys with the prefix "<namespace>:". The index has one HNSW/COSINE FLOAT32
// vector field called "embedding" plus one field per entry in properties:
// NUMERIC for integer properties and TAG for everything else.
//
// If FT.INFO succeeds the index is assumed to exist already; the property
// types are cached and nil is returned without validating dimension. An error
// is returned when the server does not know the FT.* commands, when dimension
// is not positive, or when FT.CREATE fails.
func (s *RedisStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error {
	ctx, cancel := withTimeout(ctx, s.config.ContextTimeout)
	defer cancel()

	// Check if index already exists.
	infoErr := s.client.Do(ctx, "FT.INFO", namespace).Err()
	if infoErr == nil {
		s.cacheNamespaceFieldTypes(namespace, properties)
		return nil // Index already exists
	}
	if strings.Contains(strings.ToLower(infoErr.Error()), "unknown command") {
		return fmt.Errorf("search module not available: please use Redis Stack or a Valkey bundle with search support (FT.* commands required). original error: %w", infoErr)
	}
	// Any other FT.INFO error is treated as "index does not exist yet".

	if dimension <= 0 {
		return fmt.Errorf("redis vector index %q: dimension must be > 0 (got %d)", namespace, dimension)
	}

	// Extract metadata field names from properties. Map iteration order is
	// random, so sort the names to keep the generated schema deterministic.
	metadataFields := make([]string, 0, len(properties))
	for fieldName := range properties {
		metadataFields = append(metadataFields, fieldName)
	}
	sort.Strings(metadataFields)

	// Create index with VECTOR field + metadata fields.
	keyPrefix := fmt.Sprintf("%s:", namespace)
	args := []interface{}{
		"FT.CREATE", namespace,
		"ON", "HASH",
		"PREFIX", "1", keyPrefix,
		"SCHEMA",
		// Native vector field with HNSW algorithm
		"embedding", "VECTOR", "HNSW", "6",
		"TYPE", "FLOAT32",
		"DIM", dimension,
		"DISTANCE_METRIC", "COSINE",
	}

	// Integer properties are indexed as NUMERIC; every other property is
	// indexed as TAG (exact matching on the stored string value).
	for _, field := range metadataFields {
		switch properties[field].DataType {
		case VectorStorePropertyTypeInteger:
			args = append(args, field, "NUMERIC")
		default:
			args = append(args, field, "TAG")
		}
	}

	// Create the index.
	if err := s.client.Do(ctx, args...).Err(); err != nil {
		return fmt.Errorf("failed to create semantic vector index %s: %w", namespace, err)
	}

	s.cacheNamespaceFieldTypes(namespace, properties)
	return nil
}

// GetChunk retrieves a chunk from the Redis vector store.
func (s *RedisStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() if strings.TrimSpace(id) == "" { return SearchResult{}, fmt.Errorf("id is required") } // Create key with namespace key := buildKey(namespace, id) // Get all fields from the hash result := s.client.HGetAll(ctx, key) if result.Err() != nil { return SearchResult{}, fmt.Errorf("failed to get chunk: %w", result.Err()) } fields := result.Val() if len(fields) == 0 { return SearchResult{}, fmt.Errorf("chunk not found: %s", id) } // Build SearchResult searchResult := SearchResult{ ID: id, Properties: make(map[string]interface{}), } // Parse fields for k, v := range fields { searchResult.Properties[k] = v } return searchResult, nil } // GetChunks retrieves multiple chunks from the Redis vector store. func (s *RedisStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() if len(ids) == 0 { return []SearchResult{}, nil } // Create keys with namespace keys := make([]string, len(ids)) for i, id := range ids { if strings.TrimSpace(id) == "" { return nil, fmt.Errorf("id cannot be empty at index %d", i) } keys[i] = buildKey(namespace, id) } // Use pipeline for efficient batch retrieval pipe := s.client.Pipeline() cmds := make([]*redis.MapStringStringCmd, len(keys)) for i, key := range keys { cmds[i] = pipe.HGetAll(ctx, key) } // Execute pipeline _, err := pipe.Exec(ctx) if err != nil { return nil, fmt.Errorf("failed to execute pipeline: %w", err) } // Process results var results []SearchResult for i, cmd := range cmds { if cmd.Err() != nil { // Log error but continue with other results s.logger.Debug(fmt.Sprintf("failed to get chunk %s: %v", ids[i], cmd.Err())) continue } fields := cmd.Val() if len(fields) == 0 { // Chunk not found, skip continue } // Build SearchResult searchResult := SearchResult{ 
ID: ids[i], Properties: make(map[string]interface{}), } // Parse fields for k, v := range fields { searchResult.Properties[k] = v } results = append(results, searchResult) } return results, nil } // GetAll retrieves all chunks from the Redis vector store. func (s *RedisStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() // Set default limit if not provided if limit < 0 { limit = BatchLimit } // Build Redis query from the provided queries redisQuery := buildRedisQuery(queries, s.getNamespaceFieldTypes(namespace)) // When limit=0 (get all), use internal pagination to avoid exceeding Redis MAXSEARCHRESULTS if limit == 0 { return s.getAllWithPagination(ctx, namespace, redisQuery, queries, selectFields) } // For explicit limit, cap to Redis maximum and use single query with cursor support searchLimit := limit if searchLimit > RedisMaxSearchResults { searchLimit = RedisMaxSearchResults } // Add OFFSET for pagination if cursor is provided offset, err := parseOffsetCursor(cursor) if err != nil { return nil, nil, err } results, err := s.executeSearch(ctx, namespace, redisQuery, queries, selectFields, offset, int(searchLimit)) if err != nil { return nil, nil, err } // Implement cursor-based pagination using OFFSET var nextCursor *string = nil if cursor != nil && *cursor != "" { if len(results) == int(limit) && limit > 0 { offset, err := strconv.ParseInt(*cursor, 10, 64) if err == nil { nextOffset := offset + limit nextCursorStr := strconv.FormatInt(nextOffset, 10) nextCursor = &nextCursorStr } } } else if len(results) == int(limit) && limit > 0 { nextCursorStr := strconv.FormatInt(limit, 10) nextCursor = &nextCursorStr } return results, nextCursor, nil } // getAllWithPagination fetches all matching results using internal pagination to avoid // exceeding Redis Search's MAXSEARCHRESULTS limit (default 
10,000). func (s *RedisStore) getAllWithPagination(ctx context.Context, namespace string, redisQuery string, queries []Query, selectFields []string) ([]SearchResult, *string, error) { var allResults []SearchResult offset := 0 for { pageResults, err := s.executeSearch(ctx, namespace, redisQuery, queries, selectFields, offset, RedisMaxSearchResults) if err != nil { return nil, nil, err } if len(pageResults) == 0 { break } allResults = append(allResults, pageResults...) if len(pageResults) < RedisMaxSearchResults { break } offset += len(pageResults) } return allResults, nil, nil } // executeSearch performs a single FT.SEARCH query with the given offset and limit. func (s *RedisStore) executeSearch(ctx context.Context, namespace string, redisQuery string, queries []Query, selectFields []string, offset int, searchLimit int) ([]SearchResult, error) { args := []interface{}{ "FT.SEARCH", namespace, redisQuery, } if len(selectFields) > 0 { args = append(args, "RETURN", len(selectFields)) for _, field := range selectFields { args = append(args, field) } } args = append(args, "LIMIT", offset, searchLimit, "DIALECT", "2") result := s.client.Do(ctx, args...) if result.Err() != nil { errMsg := strings.ToLower(result.Err().Error()) if isQuerySyntaxError(errMsg) { s.logger.Debug(fmt.Sprintf("FT.SEARCH DIALECT fallback triggered for namespace %s: %s", namespace, result.Err())) compatArgs := make([]interface{}, 0, len(args)-2) for i := 0; i < len(args); i++ { if i+1 < len(args) && args[i] == "DIALECT" { i++ continue } compatArgs = append(compatArgs, args[i]) } result = s.client.Do(ctx, compatArgs...) 
} if result.Err() != nil { errMsg = strings.ToLower(result.Err().Error()) if isQuerySyntaxError(errMsg) { if IsScanFallbackDisabled(ctx) { return nil, fmt.Errorf("%w: %w", ErrQuerySyntax, result.Err()) } s.logger.Debug(fmt.Sprintf("FT.SEARCH scan fallback triggered for namespace %s: %s", namespace, result.Err())) scanResults, _, scanErr := s.getAllByScan(ctx, namespace, queries, selectFields, nil, int64(searchLimit)) if scanErr != nil { return nil, scanErr } return scanResults, nil } return nil, fmt.Errorf("failed to search: %w", result.Err()) } } results, err := s.parseSearchResults(result.Val(), namespace, selectFields) if err != nil { return nil, fmt.Errorf("failed to parse search results: %w", err) } return results, nil } func (s *RedisStore) getAllByScan(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) { // Parse offset for deterministic in-memory pagination after full scan. offset, err := parseOffsetCursor(cursor) if err != nil { return nil, nil, err } all, err := s.scanAllMatchingResults(ctx, namespace, queries, selectFields) if err != nil { return nil, nil, err } // Ensure stable pagination boundaries for offset cursors across calls. 
sort.Slice(all, func(i, j int) bool { return all[i].ID < all[j].ID }) if offset > len(all) { offset = len(all) } if limit == 0 { return all[offset:], nil, nil } if limit < 0 { limit = BatchLimit } end := offset + int(limit) if end > len(all) { end = len(all) } results := all[offset:end] var next *string if end < len(all) { nextCursorStr := strconv.Itoa(end) next = &nextCursorStr } return results, next, nil } func (s *RedisStore) scanAllMatchingResults(ctx context.Context, namespace string, queries []Query, selectFields []string) ([]SearchResult, error) { if clusterClient, ok := s.client.(*redis.ClusterClient); ok { return s.scanAllMatchingResultsCluster(ctx, clusterClient, namespace, queries, selectFields) } return s.scanAllMatchingResultsSingle(ctx, s.client, namespace, queries, selectFields) } func (s *RedisStore) scanAllMatchingResultsSingle(ctx context.Context, client redis.Cmdable, namespace string, queries []Query, selectFields []string) ([]SearchResult, error) { pattern := buildKey(namespace, "*") var ( scanCursor uint64 all []SearchResult ) for { keys, nextCursor, err := client.Scan(ctx, scanCursor, pattern, BatchLimit).Result() if err != nil { return nil, fmt.Errorf("failed to scan keys: %w", err) } matches, err := s.fetchMatchingSearchResults(ctx, client, namespace, keys, queries, selectFields) if err != nil { return nil, err } all = append(all, matches...) 
scanCursor = nextCursor if scanCursor == 0 { break } } return all, nil } func (s *RedisStore) scanAllMatchingResultsCluster(ctx context.Context, client *redis.ClusterClient, namespace string, queries []Query, selectFields []string) ([]SearchResult, error) { var ( all []SearchResult allMu sync.Mutex seenIDs = make(map[string]struct{}) seenIDsMu sync.Mutex ) err := client.ForEachMaster(ctx, func(ctx context.Context, nodeClient *redis.Client) error { matches, err := s.scanAllMatchingResultsSingle(ctx, nodeClient, namespace, queries, selectFields) if err != nil { return err } unique := make([]SearchResult, 0, len(matches)) seenIDsMu.Lock() for _, match := range matches { if _, ok := seenIDs[match.ID]; ok { continue } seenIDs[match.ID] = struct{}{} unique = append(unique, match) } seenIDsMu.Unlock() if len(unique) == 0 { return nil } allMu.Lock() all = append(all, unique...) allMu.Unlock() return nil }) if err != nil { return nil, fmt.Errorf("failed to scan cluster nodes: %w", err) } return all, nil } func (s *RedisStore) fetchMatchingSearchResults(ctx context.Context, client redis.Cmdable, namespace string, keys []string, queries []Query, selectFields []string) ([]SearchResult, error) { if len(keys) == 0 { return nil, nil } pipe := client.Pipeline() cmds := make([]*redis.MapStringStringCmd, len(keys)) for i, key := range keys { cmds[i] = pipe.HGetAll(ctx, key) } if _, err := pipe.Exec(ctx); err != nil { return nil, fmt.Errorf("failed to fetch scanned keys: %w", err) } results := make([]SearchResult, 0, len(keys)) for i, cmd := range cmds { if cmd.Err() != nil { continue } fields := cmd.Val() if len(fields) == 0 { continue } key := keys[i] id := strings.TrimPrefix(key, namespace+":") if id == key { continue } properties := make(map[string]interface{}, len(fields)) for k, v := range fields { properties[k] = v } if !matchesQueriesForScan(properties, queries) { continue } searchResult := SearchResult{ ID: id, Properties: make(map[string]interface{}), } if len(selectFields) 
== 0 { searchResult.Properties = properties } else { for _, field := range selectFields { if val, ok := properties[field]; ok { searchResult.Properties[field] = val } } } results = append(results, searchResult) } return results, nil } func matchesQueriesForScan(properties map[string]interface{}, queries []Query) bool { for _, q := range queries { raw, exists := properties[q.Field] // NOTE: missing fields are treated as non-matching for most operators // (Equal, Like, GreaterThan, etc.) but pass NotEqual — i.e. a document // without the field is considered "not equal" to any value. This differs // from SQL NULL semantics where NULL != value evaluates to NULL/unknown. // Change this if scan results need to match FT.SEARCH behavior exactly. rawStr := fmt.Sprintf("%v", raw) queryStr := fmt.Sprintf("%v", q.Value) switch q.Operator { case QueryOperatorEqual: if !exists || rawStr != queryStr { return false } case QueryOperatorNotEqual: if exists && rawStr == queryStr { return false } case QueryOperatorIsNull: if exists { return false } case QueryOperatorIsNotNull: if !exists { return false } case QueryOperatorLike: if !exists || !strings.Contains(strings.ToLower(rawStr), strings.ToLower(queryStr)) { return false } case QueryOperatorGreaterThan: if !exists { return false } rawF, errR := strconv.ParseFloat(rawStr, 64) queryF, errQ := strconv.ParseFloat(queryStr, 64) if errR != nil || errQ != nil || rawF <= queryF { return false } case QueryOperatorGreaterThanOrEqual: if !exists { return false } rawF, errR := strconv.ParseFloat(rawStr, 64) queryF, errQ := strconv.ParseFloat(queryStr, 64) if errR != nil || errQ != nil || rawF < queryF { return false } case QueryOperatorLessThan: if !exists { return false } rawF, errR := strconv.ParseFloat(rawStr, 64) queryF, errQ := strconv.ParseFloat(queryStr, 64) if errR != nil || errQ != nil || rawF >= queryF { return false } case QueryOperatorLessThanOrEqual: if !exists { return false } rawF, errR := strconv.ParseFloat(rawStr, 64) queryF, 
errQ := strconv.ParseFloat(queryStr, 64) if errR != nil || errQ != nil || rawF > queryF { return false } case QueryOperatorContainsAny: if !exists { return false } propertyValues, ok := parseStringValuesForContains(raw) if !ok { return false } queryValues, ok := parseQueryContainsValues(q.Value) if !ok { return false } if !containsAnyString(propertyValues, queryValues) { return false } case QueryOperatorContainsAll: if !exists { return false } propertyValues, ok := parseStringValuesForContains(raw) if !ok { return false } queryValues, ok := parseQueryContainsValues(q.Value) if !ok { return false } if !containsAllStrings(propertyValues, queryValues) { return false } default: // Conservative fallback: require exact match semantics for unsupported operators. if !exists || rawStr != queryStr { return false } } } return true } // parseSearchResults parses FT.SEARCH results into SearchResult slice. func (s *RedisStore) parseSearchResults(result interface{}, namespace string, selectFields []string) ([]SearchResult, error) { results := []SearchResult{} // RESP3 style in Redis/Valkey: // map{ "results": [ { "id": "...", "extra_attributes": {...} } ] } switch typed := result.(type) { case map[interface{}]interface{}: rawResults, ok := typed["results"] if !ok { return results, nil } resultItems, ok := rawResults.([]interface{}) if !ok { return results, nil } for _, item := range resultItems { if parsed, ok := parseSearchResultDocument(item, namespace, selectFields); ok { results = append(results, parsed) } } return results, nil case map[string]interface{}: rawResults, ok := typed["results"] if !ok { return results, nil } resultItems, ok := rawResults.([]interface{}) if !ok { return results, nil } for _, item := range resultItems { if parsed, ok := parseSearchResultDocument(item, namespace, selectFields); ok { results = append(results, parsed) } } return results, nil case []interface{}: // RESP2 style in Redis/Valkey: // [total, "namespace:id", ["field", "value", ...], ...] 
if len(typed) < 3 { return results, nil } for i := 1; i+1 < len(typed); i += 2 { idValue := typed[i] attrsValue := typed[i+1] doc := map[string]interface{}{ "id": idValue, "extra_attributes": attrsValue, } if parsed, ok := parseSearchResultDocument(doc, namespace, selectFields); ok { results = append(results, parsed) } } return results, nil default: return results, nil } } func parseSearchResultIDs(result interface{}, namespace string) []string { ids := make([]string, 0) appendID := func(value interface{}) { id, ok := toString(value) if !ok { return } id = strings.TrimSpace(id) if id == "" { return } if namespace != "" { prefix := namespace + ":" if strings.HasPrefix(id, prefix) { id = strings.TrimPrefix(id, prefix) } } if id == "" { return } ids = append(ids, id) } extractRESP3IDs := func(rawResults interface{}) { resultItems, ok := rawResults.([]interface{}) if !ok { return } for _, item := range resultItems { switch doc := item.(type) { case map[string]interface{}: appendID(doc["id"]) case map[interface{}]interface{}: appendID(doc["id"]) default: appendID(item) } } } switch typed := result.(type) { case map[interface{}]interface{}: extractRESP3IDs(typed["results"]) case map[string]interface{}: extractRESP3IDs(typed["results"]) case []interface{}: if len(typed) < 2 { return ids } for i := 1; i < len(typed); i++ { appendID(typed[i]) // RESP2 payloads can be [total, id, attrs, id, attrs, ...]. 
if i+1 < len(typed) { switch typed[i+1].(type) { case []interface{}, map[string]interface{}, map[interface{}]interface{}: i++ } } } } return ids } func parseSearchResultDocument(resultItem interface{}, namespace string, selectFields []string) (SearchResult, bool) { var docMap map[string]interface{} switch item := resultItem.(type) { case map[string]interface{}: docMap = item case map[interface{}]interface{}: docMap = make(map[string]interface{}, len(item)) for k, v := range item { docMap[fmt.Sprintf("%v", k)] = v } default: return SearchResult{}, false } idRaw, ok := docMap["id"] if !ok { return SearchResult{}, false } id, ok := toString(idRaw) if !ok { return SearchResult{}, false } docID := id if namespace != "" { prefix := namespace + ":" if strings.HasPrefix(id, prefix) { docID = strings.TrimPrefix(id, prefix) } } attrsRaw, ok := docMap["extra_attributes"] if !ok { return SearchResult{}, false } attrs := attributesToMap(attrsRaw) if attrs == nil { return SearchResult{}, false } searchResult := SearchResult{ ID: docID, Properties: make(map[string]interface{}, len(attrs)), } for fieldName, fieldValue := range attrs { if fieldName == "score" { searchResult.Properties[fieldName] = fieldValue if scoreFloat, ok := toFloat64(fieldValue); ok { searchResult.Score = &scoreFloat } continue } if len(selectFields) > 0 && !containsField(selectFields, fieldName) { continue } searchResult.Properties[fieldName] = fieldValue } return searchResult, true } func attributesToMap(value interface{}) map[string]interface{} { switch attrs := value.(type) { case map[string]interface{}: return attrs case map[interface{}]interface{}: out := make(map[string]interface{}, len(attrs)) for k, v := range attrs { out[fmt.Sprintf("%v", k)] = v } return out case []interface{}: // RESP2 attribute pairs: ["field", "value", "field2", "value2", ...] 
if len(attrs)%2 != 0 { return nil } out := make(map[string]interface{}, len(attrs)/2) for i := 0; i+1 < len(attrs); i += 2 { key, ok := toString(attrs[i]) if !ok { continue } out[key] = attrs[i+1] } return out default: return nil } } func toString(value interface{}) (string, bool) { switch v := value.(type) { case string: return v, true case []byte: return string(v), true default: return "", false } } func toFloat64(value interface{}) (float64, bool) { switch v := value.(type) { case float64: return v, true case float32: return float64(v), true case int: return float64(v), true case int64: return float64(v), true case string: parsed, err := strconv.ParseFloat(v, 64) if err != nil { return 0, false } return parsed, true case []byte: parsed, err := strconv.ParseFloat(string(v), 64) if err != nil { return 0, false } return parsed, true default: return 0, false } } func containsField(fields []string, candidate string) bool { for _, field := range fields { if field == candidate { return true } } return false } func (s *RedisStore) cacheNamespaceFieldTypes(namespace string, properties map[string]VectorStoreProperties) { if strings.TrimSpace(namespace) == "" || len(properties) == 0 { return } fieldTypes := make(map[string]VectorStorePropertyType, len(properties)) for field, prop := range properties { fieldTypes[field] = prop.DataType } s.namespaceFieldTypesMu.Lock() defer s.namespaceFieldTypesMu.Unlock() if s.namespaceFieldTypes == nil { s.namespaceFieldTypes = make(map[string]map[string]VectorStorePropertyType) } s.namespaceFieldTypes[namespace] = fieldTypes } func (s *RedisStore) deleteNamespaceFieldTypes(namespace string) { if strings.TrimSpace(namespace) == "" { return } s.namespaceFieldTypesMu.Lock() defer s.namespaceFieldTypesMu.Unlock() delete(s.namespaceFieldTypes, namespace) } func (s *RedisStore) getNamespaceFieldTypes(namespace string) map[string]VectorStorePropertyType { if strings.TrimSpace(namespace) == "" { return nil } s.namespaceFieldTypesMu.RLock() defer 
s.namespaceFieldTypesMu.RUnlock() fieldTypes, ok := s.namespaceFieldTypes[namespace] if !ok { return nil } copied := make(map[string]VectorStorePropertyType, len(fieldTypes)) for field, dataType := range fieldTypes { copied[field] = dataType } return copied } // buildRedisQuery converts []Query to Redis query syntax func buildRedisQuery(queries []Query, fieldTypes map[string]VectorStorePropertyType) string { if len(queries) == 0 { return "*" } var conditions []string for _, query := range queries { condition := buildRedisQueryCondition(query, fieldTypes) if condition != "" { conditions = append(conditions, condition) } } if len(conditions) == 0 { return "*" } // Join conditions with space (AND operation in Redis) return strings.Join(conditions, " ") } func shouldUseNumericEquality(field string, value interface{}, fieldTypes map[string]VectorStorePropertyType) (string, bool) { if fieldTypes != nil { if dataType, ok := fieldTypes[field]; ok { if dataType == VectorStorePropertyTypeInteger { return normalizeNumericQueryValue(value) } return "", false } } return normalizeNumericQueryValue(value) } func normalizeNumericQueryValue(value interface{}) (string, bool) { switch v := value.(type) { case int: return strconv.FormatInt(int64(v), 10), true case int8: return strconv.FormatInt(int64(v), 10), true case int16: return strconv.FormatInt(int64(v), 10), true case int32: return strconv.FormatInt(int64(v), 10), true case int64: return strconv.FormatInt(v, 10), true case uint: return strconv.FormatUint(uint64(v), 10), true case uint8: return strconv.FormatUint(uint64(v), 10), true case uint16: return strconv.FormatUint(uint64(v), 10), true case uint32: return strconv.FormatUint(uint64(v), 10), true case uint64: return strconv.FormatUint(v, 10), true case float32: return strconv.FormatFloat(float64(v), 'f', -1, 32), true case float64: return strconv.FormatFloat(v, 'f', -1, 64), true case string: trimmed := strings.TrimSpace(v) if trimmed == "" { return "", false } if _, err := 
strconv.ParseFloat(trimmed, 64); err != nil { return "", false } return trimmed, true default: return "", false } } // buildRedisQueryCondition builds a single Redis query condition func buildRedisQueryCondition(query Query, fieldTypes map[string]VectorStorePropertyType) string { field := query.Field operator := query.Operator value := query.Value // Convert value to string var stringValue string switch val := value.(type) { case string: stringValue = val case int, int64, float64, bool: stringValue = fmt.Sprintf("%v", val) default: jsonData, _ := json.Marshal(val) stringValue = string(jsonData) } // Escape special characters for TAG fields escapedValue := escapeSearchValue(stringValue) // new function for TAG escaping switch operator { case QueryOperatorEqual: if numericValue, useNumeric := shouldUseNumericEquality(field, value, fieldTypes); useNumeric { return fmt.Sprintf("@%s:[%s %s]", field, numericValue, numericValue) } // TAG exact match return fmt.Sprintf("@%s:{%s}", field, escapedValue) case QueryOperatorNotEqual: if numericValue, useNumeric := shouldUseNumericEquality(field, value, fieldTypes); useNumeric { return fmt.Sprintf("-@%s:[%s %s]", field, numericValue, numericValue) } // TAG negation return fmt.Sprintf("-@%s:{%s}", field, escapedValue) case QueryOperatorLike: // Cannot do LIKE with TAGs directly; fallback to exact match return fmt.Sprintf("@%s:{%s}", field, escapedValue) case QueryOperatorGreaterThan: return fmt.Sprintf("@%s:[(%s +inf]", field, escapedValue) case QueryOperatorGreaterThanOrEqual: return fmt.Sprintf("@%s:[%s +inf]", field, escapedValue) case QueryOperatorLessThan: return fmt.Sprintf("@%s:[-inf (%s]", field, escapedValue) case QueryOperatorLessThanOrEqual: return fmt.Sprintf("@%s:[-inf %s]", field, escapedValue) case QueryOperatorIsNull: // Field not present return fmt.Sprintf("-@%s:*", field) case QueryOperatorIsNotNull: // Field exists return fmt.Sprintf("@%s:*", field) case QueryOperatorContainsAny: if values, ok := 
value.([]interface{}); ok { var orConditions []string for _, v := range values { vStr := fmt.Sprintf("%v", v) orConditions = append(orConditions, fmt.Sprintf("@%s:{%s}", field, escapeSearchValue(vStr))) } return fmt.Sprintf("(%s)", strings.Join(orConditions, " | ")) } return fmt.Sprintf("@%s:{%s}", field, escapedValue) case QueryOperatorContainsAll: if values, ok := value.([]interface{}); ok { var andConditions []string for _, v := range values { vStr := fmt.Sprintf("%v", v) andConditions = append(andConditions, fmt.Sprintf("@%s:{%s}", field, escapeSearchValue(vStr))) } return strings.Join(andConditions, " ") } return fmt.Sprintf("@%s:{%s}", field, escapedValue) default: return fmt.Sprintf("@%s:{%s}", field, escapedValue) } } // GetNearest retrieves the nearest chunks from the Redis vector store. func (s *RedisStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() // Build Redis query from the provided queries redisQuery := buildRedisQuery(queries, s.getNamespaceFieldTypes(namespace)) // Convert query embedding to binary format queryBytes := float32SliceToBytes(vector) // Build hybrid FT.SEARCH query: metadata filters + KNN vector search // The correct syntax is: (metadata_filter)=>[KNN k @embedding $vec AS score] var hybridQuery string if len(queries) > 0 { // Wrap metadata query in parentheses for hybrid syntax hybridQuery = fmt.Sprintf("(%s)", redisQuery) } else { // Wildcard for pure vector search hybridQuery = "*" } // Execute FT.SEARCH with KNN // Use large limit for "all" (limit=0) in KNN query knnLimit := limit if limit == 0 { knnLimit = math.MaxInt32 } args := []interface{}{ "FT.SEARCH", namespace, fmt.Sprintf("%s=>[KNN %d @embedding $vec AS score]", hybridQuery, knnLimit), "PARAMS", "2", "vec", queryBytes, "SORTBY", "score", } // Add RETURN clause - always include 
score for vector search // For vector search, we need to include the score field generated by KNN returnFields := []string{"score"} if len(selectFields) > 0 { returnFields = append(returnFields, selectFields...) } args = append(args, "RETURN", len(returnFields)) for _, field := range returnFields { args = append(args, field) } // Add LIMIT clause and DIALECT 2 for better query parsing searchLimit := limit if limit == 0 { searchLimit = math.MaxInt32 } args = append(args, "LIMIT", 0, int(searchLimit), "DIALECT", "2") result := s.client.Do(ctx, args...) if result.Err() != nil { errMsg := strings.ToLower(result.Err().Error()) // Some Valkey implementations reject SORTBY in KNN search (already distance-ordered). if strings.Contains(errMsg, "unexpected argument `sortby`") || strings.Contains(errMsg, "unexpected argument sortby") { compatArgs := make([]interface{}, 0, len(args)-2) for i := 0; i < len(args); i++ { if i+1 < len(args) && args[i] == "SORTBY" { i++ // skip sort field value too continue } compatArgs = append(compatArgs, args[i]) } result = s.client.Do(ctx, compatArgs...) 
} if result.Err() != nil { return nil, fmt.Errorf("native vector search failed: %w", result.Err()) } } // Parse search results results, err := s.parseSearchResults(result.Val(), namespace, selectFields) if err != nil { return nil, err } // Apply threshold filter and extract scores var filteredResults []SearchResult for _, result := range results { // Extract score from the result if scoreValue, exists := result.Properties["score"]; exists { score, ok := toFloat64(scoreValue) if !ok { continue } // Convert cosine distance to similarity: similarity = 1 - distance similarity := 1.0 - score result.Score = &similarity // Apply threshold filter if similarity >= threshold { filteredResults = append(filteredResults, result) } } else { // If no score, include the result (shouldn't happen with KNN queries) filteredResults = append(filteredResults, result) } } results = filteredResults return results, nil } // Add stores a new chunk in the Redis vector store. func (s *RedisStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() if strings.TrimSpace(id) == "" { return fmt.Errorf("id is required") } // Create key with namespace key := buildKey(namespace, id) // Prepare hash fields: binary embedding + metadata fields := make(map[string]interface{}) // Only add embedding if it's not empty if len(embedding) > 0 { // Convert float32 slice to bytes for Redis storage embeddingBytes := float32SliceToBytes(embedding) fields["embedding"] = embeddingBytes } // Add metadata fields directly (no prefix needed with proper indexing) for k, v := range metadata { switch val := v.(type) { case string: fields[k] = val case int, int64, float64, bool: fields[k] = fmt.Sprintf("%v", val) case []interface{}: // Preserve arrays as JSON to support round-trips (e.g., stream_chunks) b, err := json.Marshal(val) if err != nil { return fmt.Errorf("failed to marshal array 
metadata %s: %w", k, err) } fields[k] = string(b) default: // JSON encode complex types jsonData, err := json.Marshal(val) if err != nil { return fmt.Errorf("failed to marshal metadata field %s: %w", k, err) } fields[k] = string(jsonData) } } // Store as hash for efficient native vector search if err := s.client.HSet(ctx, key, fields).Err(); err != nil { return fmt.Errorf("failed to store semantic cache entry: %w", err) } return nil } // Delete deletes a chunk from the Redis vector store. func (s *RedisStore) Delete(ctx context.Context, namespace string, id string) error { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() if strings.TrimSpace(id) == "" { return fmt.Errorf("id is required") } // Create key with namespace key := buildKey(namespace, id) // Delete the hash key result := s.client.Del(ctx, key) if result.Err() != nil { return fmt.Errorf("failed to delete chunk %s: %w", id, result.Err()) } // Check if the key actually existed if result.Val() == 0 { return fmt.Errorf("chunk not found: %s", id) } return nil } // DeleteAll deletes all chunks from the Redis vector store. func (s *RedisStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() return s.deleteAllBySnapshot(ctx, namespace, queries) } // deleteAllBySnapshot snapshots matching ids before deleting to avoid // offset/cursor drift while mutating the dataset. 
func (s *RedisStore) deleteAllBySnapshot(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) { ids, err := s.getAllMatchingIDs(ctx, namespace, queries) if err != nil { return nil, fmt.Errorf("failed to find documents to delete: %w", err) } if len(ids) == 0 { return []DeleteResult{}, nil } // Delete this batch of documents var deleteResults []DeleteResult batchSize := BatchLimit // Process in batches to avoid overwhelming Redis for i := 0; i < len(ids); i += batchSize { end := i + batchSize if end > len(ids) { end = len(ids) } batch := ids[i:end] // Create pipeline for batch deletion pipe := s.client.Pipeline() cmds := make([]*redis.IntCmd, len(batch)) for j, id := range batch { key := buildKey(namespace, id) cmds[j] = pipe.Del(ctx, key) } // Execute pipeline _, err := pipe.Exec(ctx) if err != nil { // If pipeline fails, mark all in this batch as failed for _, id := range batch { deleteResults = append(deleteResults, DeleteResult{ ID: id, Status: DeleteStatusError, Error: fmt.Sprintf("pipeline execution failed: %v", err), }) } continue } // Process results for this batch for j, cmd := range cmds { id := batch[j] if cmd.Err() != nil { deleteResults = append(deleteResults, DeleteResult{ ID: id, Status: DeleteStatusError, Error: cmd.Err().Error(), }) } else if cmd.Val() > 0 { // Key existed and was deleted deleteResults = append(deleteResults, DeleteResult{ ID: id, Status: DeleteStatusSuccess, }) } else { // Key didn't exist deleteResults = append(deleteResults, DeleteResult{ ID: id, Status: DeleteStatusError, Error: "document not found", }) } } } return deleteResults, nil } func (s *RedisStore) getAllMatchingIDs(ctx context.Context, namespace string, queries []Query) ([]string, error) { redisQuery := buildRedisQuery(queries, s.getNamespaceFieldTypes(namespace)) offset := 0 ids := make([]string, 0) for { args := []interface{}{ "FT.SEARCH", namespace, redisQuery, "RETURN", 0, "LIMIT", offset, BatchLimit, "DIALECT", "2", } result := 
s.client.Do(ctx, args...) if result.Err() != nil { errMsg := strings.ToLower(result.Err().Error()) if isQuerySyntaxError(errMsg) { s.logger.Debug(fmt.Sprintf("FT.SEARCH DIALECT fallback triggered for namespace %s while collecting ids: %s", namespace, result.Err())) compatArgs := make([]interface{}, 0, len(args)-2) for i := 0; i < len(args); i++ { if i+1 < len(args) && args[i] == "DIALECT" { i++ continue } compatArgs = append(compatArgs, args[i]) } result = s.client.Do(ctx, compatArgs...) } if result.Err() != nil { errMsg = strings.ToLower(result.Err().Error()) if isQuerySyntaxError(errMsg) { if IsScanFallbackDisabled(ctx) { return nil, fmt.Errorf("failed to collect matching ids without scan fallback: %w", result.Err()) } s.logger.Debug(fmt.Sprintf("FT.SEARCH scan fallback triggered for namespace %s while collecting ids: %s", namespace, result.Err())) scanResults, _, scanErr := s.getAllByScan(ctx, namespace, queries, nil, nil, 0) if scanErr != nil { return nil, fmt.Errorf("failed to collect matching ids via scan fallback: %w", scanErr) } scanIDs := make([]string, 0, len(scanResults)) for _, scanResult := range scanResults { scanIDs = append(scanIDs, scanResult.ID) } return scanIDs, nil } return nil, fmt.Errorf("failed to search for matching ids: %w", result.Err()) } } pageIDs := parseSearchResultIDs(result.Val(), namespace) if len(pageIDs) == 0 { break } ids = append(ids, pageIDs...) if len(pageIDs) < BatchLimit { break } offset += len(pageIDs) } return ids, nil } // DeleteNamespace deletes a namespace from the Redis vector store. 
func (s *RedisStore) DeleteNamespace(ctx context.Context, namespace string) error { ctx, cancel := withTimeout(ctx, s.config.ContextTimeout) defer cancel() // Drop the index using FT.DROPINDEX if err := s.client.Do(ctx, "FT.DROPINDEX", namespace).Err(); err != nil { // Check if error is "Unknown Index name" - that's OK, index doesn't exist if strings.Contains(strings.ToLower(err.Error()), "unknown index name") { s.deleteNamespaceFieldTypes(namespace) return nil // Index doesn't exist, nothing to drop } return fmt.Errorf("failed to drop semantic index %s: %w", namespace, err) } s.deleteNamespaceFieldTypes(namespace) return nil } // Close closes the Redis vector store. func (s *RedisStore) Close(ctx context.Context, namespace string) error { // Close the Redis client connection return s.client.Close() } // RequiresVectors returns false because Redis can store hash data with or without vectors. func (s *RedisStore) RequiresVectors() bool { return false } // escapeSearchValue escapes special characters in search values. func escapeSearchValue(value string) string { // Escape special RediSearch characters replacer := strings.NewReplacer( "(", "\\(", ")", "\\)", "[", "\\[", "]", "\\]", "{", "\\{", "}", "\\}", "*", "\\*", "?", "\\?", "|", "\\|", "&", "\\&", "!", "\\!", "@", "\\@", "#", "\\#", "$", "\\$", "%", "\\%", "^", "\\^", "~", "\\~", "`", "\\`", "\"", "\\\"", "'", "\\'", " ", "\\ ", "-", "\\-", ".", "\\.", ",", "\\,", ) return replacer.Replace(value) } // Binary embedding conversion helpers func float32SliceToBytes(floats []float32) []byte { bytes := make([]byte, len(floats)*4) for i, f := range floats { binary.LittleEndian.PutUint32(bytes[i*4:], math.Float32bits(f)) } return bytes } // isQuerySyntaxError checks whether a lowercased error message indicates an // incompatible search query syntax. It covers error strings from Redis Stack, // Valkey Search, and other compatible engines. 
func isQuerySyntaxError(errMsg string) bool {
	// Substrings that identify a query the server could not parse.
	markers := [...]string{
		"missing `=>`",
		"invalid filter",
		"invalid query",
		"vector query clause is missing",
	}
	for _, marker := range markers {
		if strings.Contains(errMsg, marker) {
			return true
		}
	}
	return false
}

// parseOffsetCursor turns a pagination cursor into a non-negative offset.
//
// A nil or empty cursor, and a cursor that is not a base-10 integer, all map
// to offset 0 without an error. A value above math.MaxInt32 or below zero is
// rejected with an error.
func parseOffsetCursor(cursor *string) (int, error) {
	if cursor == nil || *cursor == "" {
		return 0, nil
	}

	n, parseErr := strconv.ParseInt(*cursor, 10, 64)
	switch {
	case parseErr != nil:
		// Keep existing behavior: malformed cursor is treated as offset 0.
		return 0, nil
	case n > math.MaxInt32:
		return 0, fmt.Errorf("offset value %d exceeds maximum allowed value", n)
	case n < 0:
		return 0, fmt.Errorf("offset value %d cannot be negative", n)
	}
	return int(n), nil
}

// parseStringValuesForContains normalizes a stored value into a list of
// strings for "contains" comparisons. The boolean result is always true.
//
//   - []string is returned unchanged.
//   - []interface{} has each element formatted with its default format.
//   - A blank string yields an empty list.
//   - A string whose trimmed form starts with "[" and decodes as a JSON array
//     yields the formatted array elements.
//   - Any other string yields a one-element list holding the untrimmed string.
//   - Any other type yields a one-element list holding its default format.
func parseStringValuesForContains(value interface{}) ([]string, bool) {
	switch typed := value.(type) {
	case []string:
		return typed, true

	case []interface{}:
		formatted := make([]string, len(typed))
		for i := range typed {
			formatted[i] = fmt.Sprint(typed[i])
		}
		return formatted, true

	case string:
		candidate := strings.TrimSpace(typed)
		if candidate == "" {
			return []string{}, true
		}
		// Redis scan fallback values may be JSON-encoded arrays.
		if strings.HasPrefix(candidate, "[") {
			var decoded []interface{}
			if json.Unmarshal([]byte(candidate), &decoded) == nil {
				formatted := make([]string, len(decoded))
				for i := range decoded {
					formatted[i] = fmt.Sprint(decoded[i])
				}
				return formatted, true
			}
		}
		return []string{typed}, true
	}

	return []string{fmt.Sprint(value)}, true
}

// parseQueryContainsValues extracts the list of values from a "contains"
// query operand. Only []interface{} and []string are accepted; any other type
// yields (nil, false).
func parseQueryContainsValues(value interface{}) ([]string, bool) {
	if strs, isStrings := value.([]string); isStrings {
		return strs, true
	}
	items, isItems := value.([]interface{})
	if !isItems {
		return nil, false
	}
	formatted := make([]string, len(items))
	for i := range items {
		formatted[i] = fmt.Sprint(items[i])
	}
	return formatted, true
}

// containsAnyString reports whether at least one needle occurs in haystack.
// An empty needle list never matches.
func containsAnyString(haystack []string, needles []string) bool {
	if len(needles) == 0 {
		return false
	}
	present := make(map[string]bool, len(haystack))
	for i := range haystack {
		present[haystack[i]] = true
	}
	for i := range needles {
		if present[needles[i]] {
			return true
		}
	}
	return false
}

// containsAllStrings reports whether every needle occurs in haystack.
// An empty needle list never matches.
func containsAllStrings(haystack []string, needles []string) bool {
	if len(needles) == 0 {
		return false
	}
	present := make(map[string]bool, len(haystack))
	for i := range haystack {
		present[haystack[i]] = true
	}
	for i := range needles {
		if !present[needles[i]] {
			return false
		}
	}
	return true
}

// buildKey creates a Redis key by combining namespace and id.
// The two parts are joined with a single colon.
func buildKey(namespace, id string) string {
	return namespace + ":" + id
}

// newRedisStore creates a new Redis vector store.
func newRedisStore(_ context.Context, config RedisConfig, logger schemas.Logger) (*RedisStore, error) { // Validate required fields if config.Addr == nil || config.Addr.GetValue() == "" { return nil, fmt.Errorf("redis addr is required") } if config.Username == nil { config.Username = schemas.NewEnvVar("") } if config.Password == nil { config.Password = schemas.NewEnvVar("") } db := 0 if config.DB != nil { db = config.DB.CoerceInt(0) } // TLS configuration var tlsConfig *tls.Config if config.UseTLS.CoerceBool(false) { tlsConfig = &tls.Config{ MinVersion: tls.VersionTLS12, InsecureSkipVerify: config.InsecureSkipVerify.CoerceBool(false), } if config.CACertPEM != nil && config.CACertPEM.GetValue() != "" { rootCAs, err := systemCertPoolWithCA(config.CACertPEM.GetValue()) if err != nil { return nil, fmt.Errorf("failed to configure Redis TLS CA certificate: %w", err) } tlsConfig.RootCAs = rootCAs } } clusterMode := config.ClusterMode.CoerceBool(false) var client redis.UniversalClient if clusterMode { // Redis Cluster does not support database selection if db != 0 { return nil, fmt.Errorf("redis cluster mode does not support database selection (DB must be 0)") } client = redis.NewClusterClient(&redis.ClusterOptions{ Addrs: []string{config.Addr.GetValue()}, Username: config.Username.GetValue(), Password: config.Password.GetValue(), Protocol: 3, // Explicitly use RESP3 protocol TLSConfig: tlsConfig, PoolSize: config.PoolSize, MaxActiveConns: config.MaxActiveConns, MinIdleConns: config.MinIdleConns, MaxIdleConns: config.MaxIdleConns, ConnMaxLifetime: config.ConnMaxLifetime, ConnMaxIdleTime: config.ConnMaxIdleTime, DialTimeout: config.DialTimeout, ReadTimeout: config.ReadTimeout, WriteTimeout: config.WriteTimeout, }) } else { client = redis.NewClient(&redis.Options{ Addr: config.Addr.GetValue(), Username: config.Username.GetValue(), Password: config.Password.GetValue(), DB: db, Protocol: 3, // Explicitly use RESP3 protocol TLSConfig: tlsConfig, PoolSize: config.PoolSize, 
MaxActiveConns: config.MaxActiveConns, MinIdleConns: config.MinIdleConns, MaxIdleConns: config.MaxIdleConns, ConnMaxLifetime: config.ConnMaxLifetime, ConnMaxIdleTime: config.ConnMaxIdleTime, DialTimeout: config.DialTimeout, ReadTimeout: config.ReadTimeout, WriteTimeout: config.WriteTimeout, }) } // Creating store connection store := &RedisStore{ client: client, config: config, logger: logger, namespaceFieldTypes: make(map[string]map[string]VectorStorePropertyType), } // Eagerly verify connectivity, consistent with other store constructors (e.g. Qdrant) if err := store.Ping(context.Background()); err != nil { return nil, fmt.Errorf("failed to connect to redis: %w", err) } return store, nil } func systemCertPoolWithCA(caCertPEM string) (*x509.CertPool, error) { rootCAs, err := x509.SystemCertPool() if err != nil { rootCAs = x509.NewCertPool() } if !rootCAs.AppendCertsFromPEM([]byte(caCertPEM)) { return nil, fmt.Errorf("failed to parse CA certificate PEM") } return rootCAs, nil }