first commit
This commit is contained in:
64
framework/objectstore/config.go
Normal file
64
framework/objectstore/config.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package objectstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// StoreType names the backend that an object store is served from.
type StoreType string

// Supported object storage backends.
const (
	// StoreTypeS3 selects the S3-compatible backend.
	StoreTypeS3 StoreType = "s3"
	// StoreTypeGCS selects the Google Cloud Storage backend.
	StoreTypeGCS StoreType = "gcs"
)
|
||||
|
||||
// Config holds the configuration for an object store.
|
||||
type Config struct {
|
||||
Type StoreType `json:"type"` // "s3" or "gcs"
|
||||
Bucket schemas.EnvVar `json:"bucket"`
|
||||
|
||||
// Common fields (apply to all store types)
|
||||
Prefix string `json:"prefix,omitempty"` // Key prefix for all stored objects. Default: "bifrost".
|
||||
Compress bool `json:"compress,omitempty"` // Enables gzip compression for stored objects. Default: false.
|
||||
|
||||
// S3 fields (used when Type == "s3")
|
||||
Region *schemas.EnvVar `json:"region,omitempty"`
|
||||
Endpoint *schemas.EnvVar `json:"endpoint,omitempty"`
|
||||
AccessKeyID *schemas.EnvVar `json:"access_key_id,omitempty"`
|
||||
SecretAccessKey *schemas.EnvVar `json:"secret_access_key,omitempty"`
|
||||
SessionToken *schemas.EnvVar `json:"session_token,omitempty"`
|
||||
RoleARN *schemas.EnvVar `json:"role_arn,omitempty"`
|
||||
ForcePathStyle bool `json:"force_path_style,omitempty"`
|
||||
|
||||
// GCS fields (used when Type == "gcs")
|
||||
Credentials *schemas.EnvVar `json:"credentials,omitempty"` // Deprecated: use credentials_json
|
||||
CredentialsJSON *schemas.EnvVar `json:"credentials_json,omitempty"` // Service account JSON or path
|
||||
ProjectID *schemas.EnvVar `json:"project_id,omitempty"` // GCP project ID override
|
||||
}
|
||||
|
||||
// GetPrefix returns the configured prefix or "bifrost" as default.
|
||||
func (c *Config) GetPrefix() string {
|
||||
if c.Prefix != "" {
|
||||
return c.Prefix
|
||||
}
|
||||
return "bifrost"
|
||||
}
|
||||
|
||||
// NewObjectStore creates the appropriate ObjectStore implementation based on config type.
|
||||
func NewObjectStore(ctx context.Context, cfg *Config, logger schemas.Logger) (ObjectStore, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("objectstore: config is required")
|
||||
}
|
||||
|
||||
switch cfg.Type {
|
||||
case StoreTypeS3:
|
||||
return NewS3ObjectStore(ctx, cfg, logger)
|
||||
case StoreTypeGCS:
|
||||
return NewGCSObjectStore(ctx, cfg, logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("objectstore: unsupported type %q", cfg.Type)
|
||||
}
|
||||
}
|
||||
165
framework/objectstore/gcs.go
Normal file
165
framework/objectstore/gcs.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package objectstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/storage"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
// GCSObjectStore implements ObjectStore using Google Cloud Storage.
|
||||
type GCSObjectStore struct {
|
||||
client *storage.Client
|
||||
bucket string
|
||||
compress bool
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewGCSObjectStore creates a new GCS object store from the given config.
|
||||
func NewGCSObjectStore(ctx context.Context, cfg *Config, logger schemas.Logger) (*GCSObjectStore, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("objectstore: config is nil")
|
||||
}
|
||||
bucket := cfg.Bucket.GetValue()
|
||||
if bucket == "" {
|
||||
return nil, fmt.Errorf("objectstore: gcs bucket is required")
|
||||
}
|
||||
|
||||
var opts []option.ClientOption
|
||||
|
||||
// Prefer credentials_json (used by Helm/schema) over deprecated credentials field.
|
||||
// Check both non-nil and non-empty to avoid an empty credentials_json shadowing
|
||||
// a valid deprecated credentials value.
|
||||
var creds string
|
||||
switch {
|
||||
case cfg.CredentialsJSON != nil && strings.TrimSpace(cfg.CredentialsJSON.GetValue()) != "":
|
||||
creds = strings.TrimSpace(cfg.CredentialsJSON.GetValue())
|
||||
case cfg.Credentials != nil && strings.TrimSpace(cfg.Credentials.GetValue()) != "":
|
||||
creds = strings.TrimSpace(cfg.Credentials.GetValue())
|
||||
}
|
||||
if creds != "" {
|
||||
if strings.HasPrefix(creds, "{") {
|
||||
if !json.Valid([]byte(creds)) {
|
||||
return nil, fmt.Errorf("objectstore: gcs credentials look like JSON but are not valid; check for syntax errors")
|
||||
}
|
||||
opts = append(opts, option.WithCredentialsJSON([]byte(creds)))
|
||||
} else {
|
||||
opts = append(opts, option.WithCredentialsFile(creds))
|
||||
}
|
||||
}
|
||||
|
||||
client, err := storage.NewClient(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("objectstore: failed to create gcs client: %w", err)
|
||||
}
|
||||
|
||||
return &GCSObjectStore{
|
||||
client: client,
|
||||
bucket: bucket,
|
||||
compress: cfg.Compress,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Put uploads data with optional custom metadata. When compression is enabled,
|
||||
// data is gzip-compressed before upload.
|
||||
func (g *GCSObjectStore) Put(ctx context.Context, key string, data []byte, tags map[string]string) error {
|
||||
body := data
|
||||
if g.compress {
|
||||
compressed, err := gzipCompress(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("objectstore: gzip compress: %w", err)
|
||||
}
|
||||
body = compressed
|
||||
}
|
||||
|
||||
obj := g.client.Bucket(g.bucket).Object(key)
|
||||
w := obj.NewWriter(ctx)
|
||||
w.ContentType = "application/json"
|
||||
if g.compress {
|
||||
w.ContentEncoding = "gzip"
|
||||
}
|
||||
if len(tags) > 0 {
|
||||
w.Metadata = tags
|
||||
}
|
||||
|
||||
if _, err := io.Copy(w, bytes.NewReader(body)); err != nil {
|
||||
_ = w.Close()
|
||||
return fmt.Errorf("objectstore: gcs write %s: %w", key, err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return fmt.Errorf("objectstore: gcs close writer %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves and decompresses an object by key.
|
||||
func (g *GCSObjectStore) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
r, err := g.client.Bucket(g.bucket).Object(key).NewReader(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("objectstore: gcs read %s: %w", key, err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// GCS transparently decompresses objects stored with ContentEncoding: "gzip",
|
||||
// so the bytes returned by ReadAll are already decompressed.
|
||||
body, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("objectstore: gcs read body %s: %w", key, err)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Delete removes a single object by key.
|
||||
func (g *GCSObjectStore) Delete(ctx context.Context, key string) error {
|
||||
if err := g.client.Bucket(g.bucket).Object(key).Delete(ctx); err != nil && !errors.Is(err, storage.ErrObjectNotExist) {
|
||||
return fmt.Errorf("objectstore: gcs delete %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteBatch removes multiple objects.
|
||||
func (g *GCSObjectStore) DeleteBatch(ctx context.Context, keys []string) error {
|
||||
var errs []error
|
||||
for _, key := range keys {
|
||||
if err := g.client.Bucket(g.bucket).Object(key).Delete(ctx); err != nil {
|
||||
if errors.Is(err, storage.ErrObjectNotExist) {
|
||||
continue
|
||||
}
|
||||
g.logger.Warn("objectstore: gcs delete %s: %v", key, err)
|
||||
errs = append(errs, fmt.Errorf("objectstore: gcs delete %s: %w", key, err))
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// Ping checks connectivity by writing and deleting a small object, proving
|
||||
// that the credentials have upload access (not just read). This is important
|
||||
// because HybridLogStore strips DB payloads before async upload — a read-only
|
||||
// principal would pass a read-based ping but silently fail all Put calls.
|
||||
func (g *GCSObjectStore) Ping(ctx context.Context) error {
|
||||
key := fmt.Sprintf("__bifrost_ping__/%d", time.Now().UnixNano())
|
||||
obj := g.client.Bucket(g.bucket).Object(key)
|
||||
|
||||
if err := obj.NewWriter(ctx).Close(); err != nil {
|
||||
return fmt.Errorf("objectstore: gcs ping write %s: %w", key, err)
|
||||
}
|
||||
if err := obj.Delete(ctx); err != nil && !errors.Is(err, storage.ErrObjectNotExist) {
|
||||
return fmt.Errorf("objectstore: gcs ping cleanup %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases the GCS client resources.
|
||||
func (g *GCSObjectStore) Close() error {
|
||||
return g.client.Close()
|
||||
}
|
||||
80
framework/objectstore/gzip.go
Normal file
80
framework/objectstore/gzip.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package objectstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// gzip writers and readers are pooled so that compress/decompress calls do not
// allocate a fresh codec each time.
// Follows the same pattern as core/providers/utils/decompression.go.

// gzipWriterPool hands out reusable gzip writers; callers must Reset them onto
// their own destination before use.
var gzipWriterPool = sync.Pool{
	New: func() any { return gzip.NewWriter(nil) },
}

// gzipReaderPool hands out reusable gzip readers; callers must Reset them onto
// their own source before use.
var gzipReaderPool = sync.Pool{
	New: func() any { return new(gzip.Reader) },
}

// gzipCompress returns the gzip encoding of data, using a pooled writer.
// A writer that reported an error is not put back in the pool, because it may
// be in a bad state.
func gzipCompress(data []byte) ([]byte, error) {
	out := new(bytes.Buffer)
	out.Grow(len(data) / 2) // Rough size estimate to limit regrowth.

	zw, _ := gzipWriterPool.Get().(*gzip.Writer)
	if zw != nil {
		zw.Reset(out)
	} else {
		zw = gzip.NewWriter(out)
	}

	_, err := zw.Write(data)
	if err == nil {
		err = zw.Close()
	}
	if err != nil {
		return nil, err
	}

	// Detach the writer from the output buffer before it goes back to the pool.
	zw.Reset(io.Discard)
	gzipWriterPool.Put(zw)
	return out.Bytes(), nil
}

// gzipDecompress returns the decoded contents of gzip-encoded data, using a
// pooled reader. If the pooled reader is unusable or cannot be reset onto
// data, a fresh reader is allocated instead. The reader is only returned to
// the pool after a fully successful read.
func gzipDecompress(data []byte) ([]byte, error) {
	zr, _ := gzipReaderPool.Get().(*gzip.Reader)
	if zr != nil {
		if err := zr.Reset(bytes.NewReader(data)); err != nil {
			// Reset failed: drop the pooled reader and allocate fresh below.
			zr = nil
		}
	}
	if zr == nil {
		fresh, err := gzip.NewReader(bytes.NewReader(data))
		if err != nil {
			return nil, err
		}
		zr = fresh
	}

	plain, err := io.ReadAll(zr)
	_ = zr.Close()
	if err != nil {
		return nil, err
	}

	gzipReaderPool.Put(zr)
	return plain, nil
}
|
||||
120
framework/objectstore/mock.go
Normal file
120
framework/objectstore/mock.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package objectstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// InMemoryObjectStore is a map-backed ObjectStore intended for tests.
type InMemoryObjectStore struct {
	mu      sync.RWMutex
	objects map[string][]byte
	tags    map[string]map[string]string

	// PutErr, when non-nil, is returned by every Put call to simulate failures.
	PutErr error
	// GetErr, when non-nil, is returned by every Get call to simulate failures.
	GetErr error
}

// NewInMemoryObjectStore returns an empty in-memory object store.
func NewInMemoryObjectStore() *InMemoryObjectStore {
	store := new(InMemoryObjectStore)
	store.objects = map[string][]byte{}
	store.tags = map[string]map[string]string{}
	return store
}

// Put stores a private copy of data (and of tags, if any) under key. Storing
// without tags clears any tags previously recorded for key. If PutErr is set
// it is returned and nothing is stored.
func (m *InMemoryObjectStore) Put(_ context.Context, key string, data []byte, tags map[string]string) error {
	if injected := m.PutErr; injected != nil {
		return injected
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	// Copy so later mutation of the caller's slice cannot change stored data.
	stored := make([]byte, len(data))
	copy(stored, data)
	m.objects[key] = stored

	if len(tags) == 0 {
		delete(m.tags, key)
		return nil
	}
	storedTags := make(map[string]string, len(tags))
	maps.Copy(storedTags, tags)
	m.tags[key] = storedTags
	return nil
}

// Get returns a copy of the data stored under key, or an error when the key
// is unknown. If GetErr is set it is returned instead.
func (m *InMemoryObjectStore) Get(_ context.Context, key string) ([]byte, error) {
	if injected := m.GetErr; injected != nil {
		return nil, injected
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	stored, found := m.objects[key]
	if !found {
		return nil, fmt.Errorf("objectstore: object not found: %s", key)
	}
	out := make([]byte, len(stored))
	copy(out, stored)
	return out, nil
}

// Delete removes key and its tags. Unknown keys are ignored.
func (m *InMemoryObjectStore) Delete(_ context.Context, key string) error {
	m.mu.Lock()
	delete(m.objects, key)
	delete(m.tags, key)
	m.mu.Unlock()
	return nil
}

// DeleteBatch removes every key in keys together with its tags.
func (m *InMemoryObjectStore) DeleteBatch(_ context.Context, keys []string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for i := range keys {
		delete(m.objects, keys[i])
		delete(m.tags, keys[i])
	}
	return nil
}

// Ping always succeeds.
func (m *InMemoryObjectStore) Ping(_ context.Context) error {
	return nil
}

// Close is a no-op and always succeeds.
func (m *InMemoryObjectStore) Close() error {
	return nil
}

// GetTags returns a copy of the tags stored for key, or nil when none are
// recorded. For testing assertions.
func (m *InMemoryObjectStore) GetTags(key string) map[string]string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	stored, found := m.tags[key]
	if !found {
		return nil
	}
	out := make(map[string]string, len(stored))
	maps.Copy(out, stored)
	return out
}

// Len reports how many objects are stored. For testing assertions.
func (m *InMemoryObjectStore) Len() int {
	m.mu.RLock()
	count := len(m.objects)
	m.mu.RUnlock()
	return count
}

// Keys returns all stored keys in unspecified order. For testing assertions.
func (m *InMemoryObjectStore) Keys() []string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	names := make([]string, 0, len(m.objects))
	for name := range m.objects {
		names = append(names, name)
	}
	return names
}
|
||||
28
framework/objectstore/objectstore.go
Normal file
28
framework/objectstore/objectstore.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Package objectstore provides an object storage abstraction.
// It can be used by any part of the system that needs to store or retrieve
// objects from S3 or other S3-compatible stores (MinIO, R2, ...), or from
// Google Cloud Storage through its native client.
|
||||
package objectstore
|
||||
|
||||
import "context"
|
||||
|
||||
// ObjectStore is the blob storage contract implemented by the backends in
// this package.
type ObjectStore interface {
	// Put stores data under key, attaching tags when provided.
	// Compression (e.g., gzip) is the implementation's concern, not the caller's.
	Put(ctx context.Context, key string, data []byte, tags map[string]string) error

	// Get returns the decompressed contents stored under key.
	Get(ctx context.Context, key string) ([]byte, error)

	// Delete removes the object stored under key.
	Delete(ctx context.Context, key string) error

	// DeleteBatch removes the objects stored under each of keys.
	DeleteBatch(ctx context.Context, keys []string) error

	// Ping verifies that the storage backend is reachable.
	Ping(ctx context.Context) error

	// Close frees any resources the store holds.
	Close() error
}
|
||||
145
framework/objectstore/objectstore_test.go
Normal file
145
framework/objectstore/objectstore_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package objectstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGzipRoundTrip(t *testing.T) {
|
||||
original := []byte(`{"input_history":"[{\"role\":\"user\",\"content\":\"hello world\"}]","output_message":"{\"role\":\"assistant\",\"content\":\"hi there\"}"}`)
|
||||
compressed, err := gzipCompress(original)
|
||||
if err != nil {
|
||||
t.Fatalf("gzipCompress: %v", err)
|
||||
}
|
||||
if len(compressed) >= len(original) {
|
||||
// For very small inputs gzip may be larger; just verify round-trip.
|
||||
t.Logf("compressed (%d) >= original (%d), but checking round-trip", len(compressed), len(original))
|
||||
}
|
||||
decompressed, err := gzipDecompress(compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("gzipDecompress: %v", err)
|
||||
}
|
||||
if !bytes.Equal(original, decompressed) {
|
||||
t.Fatalf("round-trip mismatch: got %q, want %q", decompressed, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipDecompress_NonGzipData(t *testing.T) {
|
||||
// gzipDecompress should return error for non-gzip data.
|
||||
_, err := gzipDecompress([]byte("not gzip"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-gzip data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeTags(t *testing.T) {
|
||||
tags := map[string]string{
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3",
|
||||
}
|
||||
encoded := encodeTags(tags)
|
||||
// URL-encoded tags, order may vary.
|
||||
if encoded == "" {
|
||||
t.Fatal("expected non-empty encoded tags")
|
||||
}
|
||||
// Verify both tags are present.
|
||||
if !bytes.Contains([]byte(encoded), []byte("provider=anthropic")) {
|
||||
t.Errorf("missing provider tag in %q", encoded)
|
||||
}
|
||||
if !bytes.Contains([]byte(encoded), []byte("model=claude-3")) {
|
||||
t.Errorf("missing model tag in %q", encoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryObjectStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewInMemoryObjectStore()
|
||||
|
||||
// Put
|
||||
if err := store.Put(ctx, "key1", []byte("data1"), map[string]string{"tag": "val"}); err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
if store.Len() != 1 {
|
||||
t.Fatalf("Len: got %d, want 1", store.Len())
|
||||
}
|
||||
|
||||
// Get
|
||||
data, err := store.Get(ctx, "key1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if !bytes.Equal(data, []byte("data1")) {
|
||||
t.Fatalf("Get: got %q, want %q", data, "data1")
|
||||
}
|
||||
|
||||
// GetTags
|
||||
tags := store.GetTags("key1")
|
||||
if tags["tag"] != "val" {
|
||||
t.Fatalf("GetTags: got %v", tags)
|
||||
}
|
||||
|
||||
// Get missing key
|
||||
_, err = store.Get(ctx, "missing")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing key")
|
||||
}
|
||||
|
||||
// Delete
|
||||
if err := store.Delete(ctx, "key1"); err != nil {
|
||||
t.Fatalf("Delete: %v", err)
|
||||
}
|
||||
if store.Len() != 0 {
|
||||
t.Fatalf("Len after delete: got %d, want 0", store.Len())
|
||||
}
|
||||
|
||||
// DeleteBatch
|
||||
_ = store.Put(ctx, "a", []byte("1"), nil)
|
||||
_ = store.Put(ctx, "b", []byte("2"), nil)
|
||||
_ = store.Put(ctx, "c", []byte("3"), nil)
|
||||
if err := store.DeleteBatch(ctx, []string{"a", "c"}); err != nil {
|
||||
t.Fatalf("DeleteBatch: %v", err)
|
||||
}
|
||||
if store.Len() != 1 {
|
||||
t.Fatalf("Len after batch delete: got %d, want 1", store.Len())
|
||||
}
|
||||
|
||||
// Ping and Close
|
||||
if err := store.Ping(ctx); err != nil {
|
||||
t.Fatalf("Ping: %v", err)
|
||||
}
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryObjectStore_SimulateErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := NewInMemoryObjectStore()
|
||||
|
||||
store.PutErr = fmt.Errorf("simulated put error")
|
||||
if err := store.Put(ctx, "key", []byte("data"), nil); err == nil {
|
||||
t.Fatal("expected error from Put")
|
||||
}
|
||||
store.PutErr = nil
|
||||
|
||||
if err := store.Put(ctx, "key", []byte("data"), nil); err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
store.GetErr = fmt.Errorf("simulated get error")
|
||||
if _, err := store.Get(ctx, "key"); err == nil {
|
||||
t.Fatal("expected error from Get")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigGetPrefix(t *testing.T) {
|
||||
c := &Config{Prefix: "custom"}
|
||||
if got := c.GetPrefix(); got != "custom" {
|
||||
t.Fatalf("GetPrefix: got %q, want %q", got, "custom")
|
||||
}
|
||||
c2 := &Config{}
|
||||
if got := c2.GetPrefix(); got != "bifrost" {
|
||||
t.Fatalf("GetPrefix default: got %q, want %q", got, "bifrost")
|
||||
}
|
||||
}
|
||||
246
framework/objectstore/s3.go
Normal file
246
framework/objectstore/s3.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package objectstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sts"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// S3ObjectStore implements ObjectStore using an S3-compatible backend.
|
||||
type S3ObjectStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
compress bool
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewS3ObjectStore creates a new S3-compatible object store from the given config.
|
||||
func NewS3ObjectStore(ctx context.Context, cfg *Config, logger schemas.Logger) (*S3ObjectStore, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("objectstore: config is nil")
|
||||
}
|
||||
|
||||
bucket := cfg.Bucket.GetValue()
|
||||
if bucket == "" {
|
||||
return nil, fmt.Errorf("objectstore: s3 bucket is required")
|
||||
}
|
||||
|
||||
// Validate static credential fields: reject half-configured credentials.
|
||||
if (cfg.AccessKeyID != nil) != (cfg.SecretAccessKey != nil) {
|
||||
return nil, fmt.Errorf("objectstore: access_key_id and secret_access_key must be set together")
|
||||
}
|
||||
if cfg.AccessKeyID != nil && (cfg.AccessKeyID.GetValue() == "" || cfg.SecretAccessKey.GetValue() == "") {
|
||||
return nil, fmt.Errorf("objectstore: access_key_id and secret_access_key must resolve to non-empty values")
|
||||
}
|
||||
if cfg.SessionToken != nil && cfg.SessionToken.GetValue() != "" &&
|
||||
(cfg.AccessKeyID == nil || cfg.SecretAccessKey == nil || cfg.AccessKeyID.GetValue() == "" || cfg.SecretAccessKey.GetValue() == "") {
|
||||
return nil, fmt.Errorf("objectstore: session_token requires access_key_id and secret_access_key")
|
||||
}
|
||||
|
||||
var opts []func(*awsconfig.LoadOptions) error
|
||||
if cfg.Region != nil && cfg.Region.GetValue() != "" {
|
||||
opts = append(opts, awsconfig.WithRegion(cfg.Region.GetValue()))
|
||||
}
|
||||
|
||||
// Static credentials if provided; otherwise default chain (IAM role, env vars, etc.)
|
||||
hasStaticConfig := cfg.AccessKeyID != nil || cfg.SecretAccessKey != nil || cfg.SessionToken != nil
|
||||
if hasStaticConfig {
|
||||
if cfg.AccessKeyID == nil || cfg.AccessKeyID.GetValue() == "" ||
|
||||
cfg.SecretAccessKey == nil || cfg.SecretAccessKey.GetValue() == "" {
|
||||
return nil, fmt.Errorf("objectstore: access_key_id and secret_access_key must both be set when using static credentials")
|
||||
}
|
||||
sessionToken := ""
|
||||
if cfg.SessionToken != nil {
|
||||
sessionToken = cfg.SessionToken.GetValue()
|
||||
}
|
||||
opts = append(opts, awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(
|
||||
cfg.AccessKeyID.GetValue(),
|
||||
cfg.SecretAccessKey.GetValue(),
|
||||
sessionToken,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("objectstore: failed to load AWS config: %w", err)
|
||||
}
|
||||
|
||||
// If a role ARN is configured, assume that role using STS.
|
||||
// Works on top of either static credentials or the default chain (instance role, env vars, etc.).
|
||||
if cfg.RoleARN != nil && cfg.RoleARN.GetValue() != "" {
|
||||
stsClient := sts.NewFromConfig(awsCfg)
|
||||
awsCfg.Credentials = aws.NewCredentialsCache(
|
||||
stscreds.NewAssumeRoleProvider(stsClient, cfg.RoleARN.GetValue()),
|
||||
)
|
||||
}
|
||||
|
||||
s3Opts := func(o *s3.Options) {
|
||||
if cfg.Endpoint != nil && cfg.Endpoint.GetValue() != "" {
|
||||
o.BaseEndpoint = aws.String(cfg.Endpoint.GetValue())
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, s3Opts)
|
||||
|
||||
return &S3ObjectStore{
|
||||
client: client,
|
||||
bucket: bucket,
|
||||
compress: cfg.Compress,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Put uploads data with optional S3 object tags. When compression is enabled,
|
||||
// data is gzip-compressed before upload.
|
||||
func (s *S3ObjectStore) Put(ctx context.Context, key string, data []byte, tags map[string]string) error {
|
||||
body := data
|
||||
if s.compress {
|
||||
compressed, err := gzipCompress(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("objectstore: gzip compress: %w", err)
|
||||
}
|
||||
body = compressed
|
||||
}
|
||||
|
||||
input := &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(key),
|
||||
Body: bytes.NewReader(body),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
if s.compress {
|
||||
input.ContentEncoding = aws.String("gzip")
|
||||
}
|
||||
|
||||
if len(tags) > 0 {
|
||||
input.Tagging = aws.String(encodeTags(tags))
|
||||
}
|
||||
|
||||
_, err := s.client.PutObject(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("objectstore: put object %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves and decompresses an object by key.
|
||||
func (s *S3ObjectStore) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
output, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("objectstore: get object %s: %w", key, err)
|
||||
}
|
||||
defer output.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(output.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("objectstore: read body %s: %w", key, err)
|
||||
}
|
||||
|
||||
// Only attempt decompression when the object was stored with gzip encoding.
|
||||
if aws.ToString(output.ContentEncoding) == "gzip" {
|
||||
decompressed, err := gzipDecompress(body)
|
||||
if err != nil {
|
||||
s.logger.Warn("objectstore: gzip decompress failed for %s: %v, returning raw bytes", key, err)
|
||||
return body, nil
|
||||
}
|
||||
return decompressed, nil
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Delete removes a single object by key.
|
||||
func (s *S3ObjectStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("objectstore: delete object %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteBatch removes multiple objects. It uses the S3 DeleteObjects API
|
||||
// which supports up to 1000 keys per call.
|
||||
func (s *S3ObjectStore) DeleteBatch(ctx context.Context, keys []string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxBatchSize = 1000
|
||||
for i := 0; i < len(keys); i += maxBatchSize {
|
||||
end := i + maxBatchSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
batch := keys[i:end]
|
||||
|
||||
objects := make([]types.ObjectIdentifier, len(batch))
|
||||
for j, key := range batch {
|
||||
objects[j] = types.ObjectIdentifier{Key: aws.String(key)}
|
||||
}
|
||||
|
||||
output, err := s.client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Delete: &types.Delete{
|
||||
Objects: objects,
|
||||
Quiet: aws.Bool(true),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("objectstore: delete objects batch starting at index %d: %w", i, err)
|
||||
}
|
||||
if len(output.Errors) > 0 {
|
||||
return fmt.Errorf("objectstore: %d objects failed to delete in batch starting at index %d", len(output.Errors), i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping checks connectivity by performing a HeadBucket call.
|
||||
// Note: HeadBucket requires the s3:ListBucket IAM permission on the bucket resource.
|
||||
func (s *S3ObjectStore) Ping(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("objectstore: head bucket %s: %w", s.bucket, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op for S3 (no persistent connections to release).
|
||||
func (s *S3ObjectStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeTags renders tags in the URL-encoded form used for S3 object tagging:
// "key1=value1&key2=value2", with keys and values query-escaped. Pairs are
// emitted in map iteration order, so their order is not stable. An empty or
// nil map yields "".
func encodeTags(tags map[string]string) string {
	var sb strings.Builder
	for name, value := range tags {
		// Every pair writes at least "=", so a non-empty builder means a
		// previous pair exists and a separator is needed.
		if sb.Len() > 0 {
			sb.WriteByte('&')
		}
		sb.WriteString(url.QueryEscape(name))
		sb.WriteByte('=')
		sb.WriteString(url.QueryEscape(value))
	}
	return sb.String()
}
|
||||
Reference in New Issue
Block a user