Files
bifrost/framework/vectorstore/qdrant.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

610 lines
16 KiB
Go

package vectorstore
import (
"context"
"fmt"
"strings"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/qdrant/go-client/qdrant"
)
// QdrantConfig represents the configuration for the Qdrant vector store.
//
// Every field is a schemas.EnvVar. newQdrantStore reads Host and APIKey with
// GetValue, Port with CoerceInt(6334) and UseTLS with CoerceBool(false); the exact
// parsing and fallback semantics live in schemas.EnvVar (not visible here — confirm there).
type QdrantConfig struct {
	Host   schemas.EnvVar `json:"host"`              // Qdrant server host - REQUIRED (a blank value is rejected by newQdrantStore)
	Port   schemas.EnvVar `json:"port"`              // Qdrant server port (fallback to 6334 for gRPC)
	APIKey schemas.EnvVar `json:"api_key,omitempty"` // API key for authentication - Optional
	UseTLS schemas.EnvVar `json:"use_tls,omitempty"` // Use TLS for connection - Optional (coerced with a false fallback)
}
// QdrantStore represents the Qdrant vector store.
//
// Namespaces are Qdrant collections, and string IDs are converted to Qdrant point
// IDs with parsePointID (UUID strings). Instances are built by newQdrantStore and
// the underlying client is released by Close.
type QdrantStore struct {
	client *qdrant.Client // Qdrant client used for every operation
	logger schemas.Logger // receives debug logs for non-fatal problems (failed index creation, skipped IDs, ...)
}
// Ping checks if the Qdrant server is reachable by issuing a health check.
// A nil return means the server answered; otherwise the client error is returned as-is.
func (s *QdrantStore) Ping(ctx context.Context) error {
	if _, err := s.client.HealthCheck(ctx); err != nil {
		return err
	}
	return nil
}
// CreateNamespace creates a new collection in the Qdrant vector store (if it does not
// already exist) and creates a payload index for every entry in properties.
//
// A missing collection is created with vectors of size dimension compared by cosine
// distance; dimension must be positive in that case. An existing collection is left
// untouched and its vector size is not compared with dimension.
//
// Payload indexes are best-effort: integer properties get an integer index, boolean
// properties a bool index and every other type a keyword index. A failure to create
// an index is logged at debug level and does not fail the call.
func (s *QdrantStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]VectorStoreProperties) error {
	exists, err := s.client.CollectionExists(ctx, namespace)
	if err != nil {
		return fmt.Errorf("failed to check collection existence: %w", err)
	}
	if !exists {
		// Guard the uint64 conversion below: a negative dimension would otherwise wrap
		// around to an enormous vector size instead of producing a clear error.
		if dimension <= 0 {
			return fmt.Errorf("invalid vector dimension %d: must be positive", dimension)
		}
		err = s.client.CreateCollection(ctx, &qdrant.CreateCollection{
			CollectionName: namespace,
			VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
				Size:     uint64(dimension),
				Distance: qdrant.Distance_Cosine,
			}),
		})
		if err != nil {
			return fmt.Errorf("failed to create collection: %w", err)
		}
	}
	for fieldName, prop := range properties {
		var fieldType qdrant.FieldType
		switch prop.DataType {
		case VectorStorePropertyTypeInteger:
			fieldType = qdrant.FieldType_FieldTypeInteger
		case VectorStorePropertyTypeBoolean:
			fieldType = qdrant.FieldType_FieldTypeBool
		default:
			fieldType = qdrant.FieldType_FieldTypeKeyword
		}
		_, err = s.client.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{
			CollectionName: namespace,
			FieldName:      fieldName,
			FieldType:      &fieldType,
		})
		if err != nil {
			// Deliberately non-fatal (e.g. the index may already exist); keep going.
			s.logger.Debug(fmt.Sprintf("failed to create index for field %s: %v", fieldName, err))
		}
	}
	return nil
}
// DeleteNamespace deletes a collection from the Qdrant vector store.
//
// Deleting a collection that does not exist is a no-op and returns nil. Errors from
// the existence check and from the deletion are wrapped with context (%w), so the
// underlying client error remains reachable through errors.Is / errors.As.
func (s *QdrantStore) DeleteNamespace(ctx context.Context, namespace string) error {
	exists, err := s.client.CollectionExists(ctx, namespace)
	if err != nil {
		return fmt.Errorf("failed to check collection existence: %w", err)
	}
	if !exists {
		return nil
	}
	if err := s.client.DeleteCollection(ctx, namespace); err != nil {
		return fmt.Errorf("failed to delete collection: %w", err)
	}
	return nil
}
// GetChunk retrieves a single point from the Qdrant vector store.
//
// id must be non-blank and accepted by parsePointID. The point's payload is returned
// as Properties, and ID echoes the id argument. A point that is not present is
// reported as an error.
func (s *QdrantStore) GetChunk(ctx context.Context, namespace string, id string) (SearchResult, error) {
	var empty SearchResult
	if len(strings.TrimSpace(id)) == 0 {
		return empty, fmt.Errorf("id is required")
	}
	target, parseErr := parsePointID(id)
	if parseErr != nil {
		return empty, fmt.Errorf("invalid id format: %w", parseErr)
	}
	request := &qdrant.GetPoints{
		CollectionName: namespace,
		Ids:            []*qdrant.PointId{target},
		WithPayload:    qdrant.NewWithPayload(true),
	}
	found, getErr := s.client.Get(ctx, request)
	switch {
	case getErr != nil:
		return empty, fmt.Errorf("failed to get point: %w", getErr)
	case len(found) == 0:
		return empty, fmt.Errorf("not found: %s", id)
	}
	return SearchResult{ID: id, Properties: payloadToMap(found[0].Payload)}, nil
}
// GetChunks retrieves multiple points from the Qdrant vector store.
//
// Blank IDs are ignored, and IDs rejected by parsePointID are skipped with a debug log.
// When no usable ID remains, an empty (non-nil) slice is returned without contacting
// Qdrant. IDs that do not exist are simply absent from the result.
func (s *QdrantStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]SearchResult, error) {
	wanted := make([]*qdrant.PointId, 0, len(ids))
	for _, raw := range ids {
		if strings.TrimSpace(raw) == "" {
			continue
		}
		pid, parseErr := parsePointID(raw)
		if parseErr != nil {
			s.logger.Debug(fmt.Sprintf("skipping invalid id %s: %v", raw, parseErr))
			continue
		}
		wanted = append(wanted, pid)
	}
	if len(wanted) == 0 {
		return []SearchResult{}, nil
	}

	found, getErr := s.client.Get(ctx, &qdrant.GetPoints{
		CollectionName: namespace,
		Ids:            wanted,
		WithPayload:    qdrant.NewWithPayload(true),
	})
	if getErr != nil {
		return nil, fmt.Errorf("failed to get points: %w", getErr)
	}

	out := make([]SearchResult, len(found))
	for i, p := range found {
		out[i] = SearchResult{
			ID:         pointIDToString(p.Id),
			Properties: payloadToMap(p.Payload),
		}
	}
	return out, nil
}
// GetAll retrieves all points with optional filtering and pagination.
//
// queries are ANDed into a Qdrant filter. selectFields, when non-empty, restricts
// the returned properties to those keys. limit is the page size: values <= 0 default
// to 100, and very large values are clamped to what Qdrant's uint32 limit can carry.
//
// cursor, when non-nil and non-empty, must be a cursor returned by a previous call.
// It names the first point of the page to return. The returned cursor is non-nil
// only when at least one more point exists after this page. An unparseable cursor is
// reported as an error rather than silently restarting from the first page, which
// would make cursor-driven loops repeat forever.
func (s *QdrantStore) GetAll(ctx context.Context, namespace string, queries []Query, selectFields []string, cursor *string, limit int64) ([]SearchResult, *string, error) {
	const (
		defaultPageSize = 100
		// One extra point is fetched below, so stay strictly below the uint32 maximum.
		maxPageSize = 1<<32 - 2
	)
	pageSize := limit
	switch {
	case pageSize <= 0:
		pageSize = defaultPageSize
	case pageSize > maxPageSize:
		pageSize = maxPageSize
	}

	var offset *qdrant.PointId
	if cursor != nil && *cursor != "" {
		var err error
		offset, err = parsePointID(*cursor)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid cursor %q: %w", *cursor, err)
		}
	}

	// Qdrant treats Offset as inclusive. Fetch one point beyond the page: it is not
	// returned now, but its ID is exactly where the next page must start. Handing out
	// the last returned ID instead would repeat that point on the next page.
	fetchLimit := uint32(pageSize + 1)
	points, err := s.client.Scroll(ctx, &qdrant.ScrollPoints{
		CollectionName: namespace,
		Filter:         buildQdrantFilter(queries),
		Limit:          &fetchLimit,
		Offset:         offset,
		WithPayload:    qdrant.NewWithPayload(true),
	})
	if err != nil {
		return nil, nil, fmt.Errorf("failed to scroll points: %w", err)
	}

	var next *string
	if int64(len(points)) > pageSize {
		nextID := pointIDToString(points[pageSize].Id)
		next = &nextID
		points = points[:pageSize]
	}

	results := make([]SearchResult, 0, len(points))
	for _, point := range points {
		results = append(results, SearchResult{
			ID:         pointIDToString(point.Id),
			Properties: filterProperties(payloadToMap(point.Payload), selectFields),
		})
	}
	return results, next, nil
}
// GetNearest retrieves the nearest points to a vector.
//
// queries are ANDed into a Qdrant filter, at most limit points are returned (10 when
// limit <= 0), and only points scoring at least threshold are kept. Every result
// carries its score, and its properties are narrowed to selectFields when that is
// non-empty.
func (s *QdrantStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []Query, selectFields []string, threshold float64, limit int64) ([]SearchResult, error) {
	var topK uint64 = 10
	if limit > 0 {
		topK = uint64(limit)
	}
	minScore := float32(threshold)

	hits, err := s.client.Query(ctx, &qdrant.QueryPoints{
		CollectionName: namespace,
		Query:          qdrant.NewQuery(vector...),
		Filter:         buildQdrantFilter(queries),
		Limit:          &topK,
		WithPayload:    qdrant.NewWithPayload(true),
		ScoreThreshold: &minScore,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to search points: %w", err)
	}

	results := make([]SearchResult, len(hits))
	for i, hit := range hits {
		score := float64(hit.Score)
		results[i] = SearchResult{
			ID:         pointIDToString(hit.Id),
			Score:      &score,
			Properties: filterProperties(payloadToMap(hit.Payload), selectFields),
		}
	}
	return results, nil
}
// Add stores a new point in the Qdrant vector store, replacing any point with the same ID.
//
// id must be non-blank and accepted by parsePointID. metadata becomes the payload,
// and embedding becomes the vector when non-empty. The upsert waits for Qdrant to
// apply the write before returning.
func (s *QdrantStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error {
	if strings.TrimSpace(id) == "" {
		return fmt.Errorf("id is required")
	}
	pid, err := parsePointID(id)
	if err != nil {
		return fmt.Errorf("invalid id format (must be UUID): %w", err)
	}

	var vectors *qdrant.Vectors
	if len(embedding) > 0 {
		vectors = qdrant.NewVectors(embedding...)
	}
	upsert := &qdrant.UpsertPoints{
		CollectionName: namespace,
		Wait:           qdrant.PtrOf(true),
		Points: []*qdrant.PointStruct{{
			Id:      pid,
			Payload: mapToPayload(metadata),
			Vectors: vectors,
		}},
	}
	if _, err = s.client.Upsert(ctx, upsert); err != nil {
		return fmt.Errorf("failed to upsert point: %w", err)
	}
	return nil
}
// Delete removes a point from the Qdrant vector store.
//
// id must be non-blank and accepted by parsePointID. The call waits for Qdrant to
// apply the deletion, as Add does for writes, so a read issued right after Delete
// no longer sees the point. Client errors are wrapped with context (%w).
func (s *QdrantStore) Delete(ctx context.Context, namespace string, id string) error {
	if strings.TrimSpace(id) == "" {
		return fmt.Errorf("id is required")
	}
	pointID, err := parsePointID(id)
	if err != nil {
		return fmt.Errorf("invalid id format: %w", err)
	}
	if _, err := s.client.Delete(ctx, &qdrant.DeletePoints{
		CollectionName: namespace,
		Points:         qdrant.NewPointsSelector(pointID),
		Wait:           qdrant.PtrOf(true),
	}); err != nil {
		return fmt.Errorf("failed to delete point: %w", err)
	}
	return nil
}
// DeleteAll removes multiple points matching the filter.
//
// Every point in namespace that matches queries is deleted (every point when queries
// yields no conditions), and one DeleteResult is returned per point found. The
// matching IDs are gathered by paging through Scroll until it is exhausted: a single
// Scroll call only returns one page (Qdrant's default limit is 10).
//
// With a filter, the deletion is a single delete-by-filter call. Without one, the
// gathered IDs are deleted in batches. Deletions wait for Qdrant to apply them, as Add
// does. A failed delete call marks the affected results DeleteStatusError. The
// returned error is non-nil only when listing the matching points fails.
func (s *QdrantStore) DeleteAll(ctx context.Context, namespace string, queries []Query) ([]DeleteResult, error) {
	filter := buildQdrantFilter(queries)
	pointIDs, err := s.scrollAllPointIDs(ctx, namespace, filter)
	if err != nil {
		return nil, err
	}
	if len(pointIDs) == 0 {
		return []DeleteResult{}, nil
	}

	results := make([]DeleteResult, len(pointIDs))
	for i, id := range pointIDs {
		results[i] = DeleteResult{
			ID:     pointIDToString(id),
			Status: DeleteStatusSuccess,
		}
	}
	// markFailed flags results[from:to] as failed with deleteErr.
	markFailed := func(from, to int, deleteErr error) {
		for i := from; i < to; i++ {
			results[i].Status = DeleteStatusError
			results[i].Error = deleteErr.Error()
		}
	}

	if filter != nil {
		_, deleteErr := s.client.Delete(ctx, &qdrant.DeletePoints{
			CollectionName: namespace,
			Points:         qdrant.NewPointsSelectorFilter(filter),
			Wait:           qdrant.PtrOf(true),
		})
		if deleteErr != nil {
			markFailed(0, len(results), deleteErr)
		}
		return results, nil
	}

	// No filter: delete by ID, batched to keep each request reasonably small.
	const deleteBatchSize = 1000
	for start := 0; start < len(pointIDs); start += deleteBatchSize {
		end := start + deleteBatchSize
		if end > len(pointIDs) {
			end = len(pointIDs)
		}
		_, deleteErr := s.client.Delete(ctx, &qdrant.DeletePoints{
			CollectionName: namespace,
			Points:         qdrant.NewPointsSelectorIDs(pointIDs[start:end]),
			Wait:           qdrant.PtrOf(true),
		})
		if deleteErr != nil {
			markFailed(start, end, deleteErr)
		}
	}
	return results, nil
}

// scrollAllPointIDs pages through namespace with Scroll and returns the IDs of every
// point matching filter (every point when filter is nil), without fetching payloads.
func (s *QdrantStore) scrollAllPointIDs(ctx context.Context, namespace string, filter *qdrant.Filter) ([]*qdrant.PointId, error) {
	const pageSize = 1000
	// Scroll's Offset is inclusive, so request one extra point per call: it is not part
	// of the current page, but its ID is where the next page starts.
	fetchLimit := uint32(pageSize + 1)
	var (
		ids    []*qdrant.PointId
		offset *qdrant.PointId
	)
	for {
		points, err := s.client.Scroll(ctx, &qdrant.ScrollPoints{
			CollectionName: namespace,
			Filter:         filter,
			Limit:          &fetchLimit,
			Offset:         offset,
			WithPayload:    qdrant.NewWithPayload(false),
		})
		if err != nil {
			return nil, fmt.Errorf("failed to scroll points: %w", err)
		}
		page := points
		if len(page) > pageSize {
			page = page[:pageSize]
		}
		for _, p := range page {
			ids = append(ids, p.Id)
		}
		if len(points) <= pageSize {
			return ids, nil
		}
		offset = points[pageSize].Id
	}
}
// Close closes the Qdrant client connection and returns the client's close error.
// ctx and namespace are not used here; presumably they exist to match the shared
// vector-store method set.
func (s *QdrantStore) Close(ctx context.Context, namespace string) error {
	closeErr := s.client.Close()
	return closeErr
}
// RequiresVectors always reports true: Qdrant is a dedicated vector database, so
// every point/entry stored through this backend is expected to carry a vector.
func (s *QdrantStore) RequiresVectors() bool {
	const qdrantNeedsVectors = true
	return qdrantNeedsVectors
}
// newQdrantStore creates a new Qdrant vector store and verifies the server answers a
// health check before returning it.
//
// config must be non-nil with a non-blank Host. Port is coerced with a 6334 fallback
// and UseTLS with a false fallback. The client's server-version compatibility check is
// skipped. If the health check fails, the freshly created client is closed before
// returning, so its connection is not leaked.
func newQdrantStore(ctx context.Context, config *QdrantConfig, logger schemas.Logger) (*QdrantStore, error) {
	if config == nil {
		return nil, fmt.Errorf("qdrant config is required")
	}
	if strings.TrimSpace(config.Host.GetValue()) == "" {
		return nil, fmt.Errorf("qdrant host is required")
	}
	client, err := qdrant.NewClient(&qdrant.Config{
		Host:                   config.Host.GetValue(),
		Port:                   config.Port.CoerceInt(6334),
		APIKey:                 config.APIKey.GetValue(),
		UseTLS:                 config.UseTLS.CoerceBool(false),
		SkipCompatibilityCheck: true,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create qdrant client: %w", err)
	}
	if _, err := client.HealthCheck(ctx); err != nil {
		// Best-effort cleanup: the health-check failure is the error worth reporting.
		_ = client.Close()
		return nil, fmt.Errorf("failed to connect to qdrant: %w", err)
	}
	return &QdrantStore{
		client: client,
		logger: logger,
	}, nil
}
// parsePointID validates id as a UUID (any form uuid.Parse accepts) and wraps the
// original string, unchanged, as a Qdrant point ID. On failure the uuid.Parse error
// is returned with a nil point ID.
func parsePointID(id string) (*qdrant.PointId, error) {
	var pid *qdrant.PointId
	_, err := uuid.Parse(id)
	if err == nil {
		pid = qdrant.NewID(id)
	}
	return pid, err
}
// pointIDToString renders a Qdrant point ID as a string: UUID IDs verbatim, numeric IDs
// in decimal. A nil ID or an unrecognized variant yields "".
func pointIDToString(id *qdrant.PointId) string {
	if id == nil {
		return ""
	}
	if u, ok := id.PointIdOptions.(*qdrant.PointId_Uuid); ok {
		return u.Uuid
	}
	if n, ok := id.PointIdOptions.(*qdrant.PointId_Num); ok {
		return fmt.Sprintf("%d", n.Num)
	}
	return ""
}
// payloadToMap converts a Qdrant payload into plain Go values via valueToInterface.
// It always returns a non-nil map (empty for a nil payload).
func payloadToMap(payload map[string]*qdrant.Value) map[string]interface{} {
	out := make(map[string]interface{}, len(payload))
	for key, value := range payload {
		out[key] = valueToInterface(value)
	}
	return out
}
// valueToInterface converts a single Qdrant payload value into a plain Go value:
// string, int64, float64, bool, []interface{} (lists, recursively) or
// map[string]interface{} (structs, recursively). Nil values and any other kind
// (including null) become nil.
func valueToInterface(v *qdrant.Value) interface{} {
	if v == nil {
		return nil
	}
	switch kind := v.Kind.(type) {
	case *qdrant.Value_StructValue:
		return payloadToMap(kind.StructValue.Fields)
	case *qdrant.Value_ListValue:
		items := kind.ListValue.Values
		converted := make([]interface{}, 0, len(items))
		for _, item := range items {
			converted = append(converted, valueToInterface(item))
		}
		return converted
	case *qdrant.Value_BoolValue:
		return kind.BoolValue
	case *qdrant.Value_DoubleValue:
		return kind.DoubleValue
	case *qdrant.Value_IntegerValue:
		return kind.IntegerValue
	case *qdrant.Value_StringValue:
		return kind.StringValue
	}
	return nil
}
// mapToPayload converts a metadata map into a Qdrant payload. A nil map yields an
// empty payload.
//
// qdrant.NewValueMap does not handle typed slices such as []string directly. The
// go-client documents that it panics on unsupported types; confirm against the pinned
// client version. Every value is therefore normalized first, recursively, so nested
// lists and maps are covered too:
//   - typed slices become []interface{};
//   - map[string]string becomes map[string]interface{}.
//
// The caller's map and its nested values are never modified.
func mapToPayload(m map[string]interface{}) map[string]*qdrant.Value {
	if m == nil {
		return make(map[string]*qdrant.Value)
	}
	converted := make(map[string]interface{}, len(m))
	for k, v := range m {
		converted[k] = normalizeQdrantPayloadValue(v)
	}
	return qdrant.NewValueMap(converted)
}

// normalizeQdrantPayloadValue returns v rewritten into the generic []interface{} /
// map[string]interface{} shapes, recursing into nested maps and lists. Values of
// any other type are returned unchanged.
func normalizeQdrantPayloadValue(v interface{}) interface{} {
	switch val := v.(type) {
	case map[string]interface{}:
		out := make(map[string]interface{}, len(val))
		for k, item := range val {
			out[k] = normalizeQdrantPayloadValue(item)
		}
		return out
	case map[string]string:
		out := make(map[string]interface{}, len(val))
		for k, item := range val {
			out[k] = item
		}
		return out
	case []interface{}:
		out := make([]interface{}, len(val))
		for i, item := range val {
			out[i] = normalizeQdrantPayloadValue(item)
		}
		return out
	case []map[string]interface{}:
		out := make([]interface{}, len(val))
		for i, item := range val {
			out[i] = normalizeQdrantPayloadValue(item)
		}
		return out
	case []string:
		return qdrantAnySlice(val)
	case []int:
		return qdrantAnySlice(val)
	case []int64:
		return qdrantAnySlice(val)
	case []float64:
		return qdrantAnySlice(val)
	case []float32:
		return qdrantAnySlice(val)
	case []bool:
		return qdrantAnySlice(val)
	default:
		return v
	}
}

// qdrantAnySlice copies a typed slice of scalars into a []interface{}.
func qdrantAnySlice[T any](items []T) []interface{} {
	out := make([]interface{}, len(items))
	for i, item := range items {
		out[i] = item
	}
	return out
}
// filterProperties narrows props to the keys listed in selectFields. Requested keys that
// are absent from props are skipped. With no selectFields, props itself (the same map,
// not a copy) is returned. Otherwise the result is a new non-nil map.
func filterProperties(props map[string]interface{}, selectFields []string) map[string]interface{} {
	if len(selectFields) > 0 {
		picked := make(map[string]interface{}, len(selectFields))
		for _, name := range selectFields {
			val, present := props[name]
			if !present {
				continue
			}
			picked[name] = val
		}
		return picked
	}
	return props
}
// buildQdrantFilter ANDs the conditions produced by buildQdrantCondition for each query
// into a Qdrant filter (all in Must). Queries that produce no condition are dropped. When
// nothing remains, nil is returned, meaning "no filter".
func buildQdrantFilter(queries []Query) *qdrant.Filter {
	conditions := make([]*qdrant.Condition, 0, len(queries))
	for i := range queries {
		if c := buildQdrantCondition(queries[i]); c != nil {
			conditions = append(conditions, c)
		}
	}
	if len(conditions) == 0 {
		return nil
	}
	return &qdrant.Filter{Must: conditions}
}
// buildQdrantCondition translates one Query into a Qdrant condition, or nil when the
// query cannot be expressed (e.g. a range or LIKE with an unsupported value type).
//
// ContainsAny and ContainsAll accept []string, []int, []int64 and []interface{}
// operands:
//   - ContainsAny matches when any listed value matches.
//   - ContainsAll requires a match for every listed value.
//
// Other operand types (including empty lists) fall back to a plain match on the
// operand. Unknown operators are treated as equality.
func buildQdrantCondition(q Query) *qdrant.Condition {
	field := q.Field
	switch q.Operator {
	case QueryOperatorEqual:
		return buildMatchCondition(field, q.Value)
	case QueryOperatorNotEqual:
		matchCond := buildMatchCondition(field, q.Value)
		if matchCond == nil {
			return nil
		}
		return qdrant.NewFilterAsCondition(&qdrant.Filter{
			MustNot: []*qdrant.Condition{matchCond},
		})
	case QueryOperatorGreaterThan:
		return buildRangeCondition(field, q.Value, "gt")
	case QueryOperatorGreaterThanOrEqual:
		return buildRangeCondition(field, q.Value, "gte")
	case QueryOperatorLessThan:
		return buildRangeCondition(field, q.Value, "lt")
	case QueryOperatorLessThanOrEqual:
		return buildRangeCondition(field, q.Value, "lte")
	case QueryOperatorIsNull:
		return qdrant.NewIsNull(field)
	case QueryOperatorIsNotNull:
		return qdrant.NewFilterAsCondition(&qdrant.Filter{
			MustNot: []*qdrant.Condition{qdrant.NewIsNull(field)},
		})
	case QueryOperatorContainsAny:
		switch v := q.Value.(type) {
		case []string:
			return qdrant.NewMatchKeywords(field, v...)
		case []int:
			int64s := make([]int64, len(v))
			for i, val := range v {
				int64s[i] = int64(val)
			}
			return qdrant.NewMatchInts(field, int64s...)
		case []int64:
			return qdrant.NewMatchInts(field, v...)
		case []interface{}:
			// Without this, a generic list would be matched as the literal "%v"
			// rendering of the whole slice. OR together one match per element instead.
			var should []*qdrant.Condition
			for _, item := range v {
				if cond := buildMatchCondition(field, item); cond != nil {
					should = append(should, cond)
				}
			}
			if len(should) > 0 {
				return qdrant.NewFilterAsCondition(&qdrant.Filter{Should: should})
			}
		}
		return buildMatchCondition(field, q.Value)
	case QueryOperatorContainsAll:
		if values, ok := qdrantListValues(q.Value); ok {
			var mustConditions []*qdrant.Condition
			for _, v := range values {
				if cond := buildMatchCondition(field, v); cond != nil {
					mustConditions = append(mustConditions, cond)
				}
			}
			if len(mustConditions) > 0 {
				return qdrant.NewFilterAsCondition(&qdrant.Filter{
					Must: mustConditions,
				})
			}
		}
		return buildMatchCondition(field, q.Value)
	case QueryOperatorLike:
		if str, ok := q.Value.(string); ok {
			return qdrant.NewMatchText(field, str)
		}
		return nil
	default:
		return buildMatchCondition(field, q.Value)
	}
}

// qdrantListValues returns the elements of a list-valued query operand as
// []interface{}. ok is false when value is not []string, []int, []int64 or
// []interface{}.
func qdrantListValues(value interface{}) ([]interface{}, bool) {
	switch v := value.(type) {
	case []interface{}:
		return v, true
	case []string:
		out := make([]interface{}, len(v))
		for i, item := range v {
			out[i] = item
		}
		return out, true
	case []int:
		out := make([]interface{}, len(v))
		for i, item := range v {
			out[i] = item
		}
		return out, true
	case []int64:
		out := make([]interface{}, len(v))
		for i, item := range v {
			out[i] = item
		}
		return out, true
	default:
		return nil, false
	}
}
// buildMatchCondition builds an exact-match condition on field:
//   - strings match as keywords;
//   - int, int32 and int64 match as integers;
//   - bools match as booleans.
//
// Any other value is matched as a keyword against its fmt "%v" rendering.
func buildMatchCondition(field string, value interface{}) *qdrant.Condition {
	if s, ok := value.(string); ok {
		return qdrant.NewMatchKeyword(field, s)
	}
	if b, ok := value.(bool); ok {
		return qdrant.NewMatchBool(field, b)
	}
	switch n := value.(type) {
	case int64:
		return qdrant.NewMatchInt(field, n)
	case int32:
		return qdrant.NewMatchInt(field, int64(n))
	case int:
		return qdrant.NewMatchInt(field, int64(n))
	}
	return qdrant.NewMatchKeyword(field, fmt.Sprintf("%v", value))
}
// buildRangeCondition builds a numeric range condition on field with a single bound
// set according to op ("gt", "gte", "lt" or "lte"). value must be an int, int32,
// int64, float32 or float64; anything else yields nil. An unrecognized op leaves the
// range unbounded.
func buildRangeCondition(field string, value interface{}, op string) *qdrant.Condition {
	var bound float64
	switch n := value.(type) {
	case float64:
		bound = n
	case float32:
		bound = float64(n)
	case int64:
		bound = float64(n)
	case int32:
		bound = float64(n)
	case int:
		bound = float64(n)
	default:
		return nil
	}

	limit := qdrant.PtrOf(bound)
	rng := &qdrant.Range{}
	switch op {
	case "gt":
		rng.Gt = limit
	case "gte":
		rng.Gte = limit
	case "lt":
		rng.Lt = limit
	case "lte":
		rng.Lte = limit
	}
	return qdrant.NewRange(field, rng)
}