first commit
This commit is contained in:
516
core/internal/llmtests/audio_validation.go
Normal file
516
core/internal/llmtests/audio_validation.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hajimehoshi/go-mp3"
|
||||
)
|
||||
|
||||
// AllowedAudioFormats defines the set of valid audio formats for speech synthesis
|
||||
var AllowedAudioFormats = map[string]bool{
|
||||
"flac": true, "mp3": true, "mp4": true, "mpeg": true,
|
||||
"mpga": true, "m4a": true, "ogg": true, "wav": true, "webm": true,
|
||||
}
|
||||
|
||||
// AudioValidationResult contains the results of audio validation
|
||||
type AudioValidationResult struct {
|
||||
Valid bool
|
||||
Format string
|
||||
MagicBytesValid bool
|
||||
DecodeValid bool
|
||||
FileSize int64
|
||||
Errors []string
|
||||
}
|
||||
|
||||
// ValidateAudioFile validates an audio file by checking magic bytes and attempting decode
|
||||
func ValidateAudioFile(t *testing.T, filePath string, expectedFormat string) error {
|
||||
t.Helper()
|
||||
|
||||
result := validateAudioFileInternal(filePath, expectedFormat)
|
||||
|
||||
if !result.Valid {
|
||||
return fmt.Errorf("audio validation failed for %s (format: %s): %s",
|
||||
filePath, expectedFormat, strings.Join(result.Errors, "; "))
|
||||
}
|
||||
|
||||
t.Logf("✅ Audio validation passed: format=%s, size=%d bytes, magic_bytes=%v, decode=%v",
|
||||
result.Format, result.FileSize, result.MagicBytesValid, result.DecodeValid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAudioBytes validates audio bytes by checking magic bytes and attempting decode
|
||||
func ValidateAudioBytes(t *testing.T, audioData []byte, expectedFormat string) error {
|
||||
t.Helper()
|
||||
|
||||
result := validateAudioBytesInternal(audioData, expectedFormat)
|
||||
|
||||
if !result.Valid {
|
||||
return fmt.Errorf("audio validation failed (format: %s): %s",
|
||||
expectedFormat, strings.Join(result.Errors, "; "))
|
||||
}
|
||||
|
||||
t.Logf("✅ Audio validation passed: format=%s, size=%d bytes, magic_bytes=%v, decode=%v",
|
||||
result.Format, len(audioData), result.MagicBytesValid, result.DecodeValid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveAndValidateAudio saves audio bytes to a temp file, validates it, and registers cleanup.
|
||||
// It auto-detects the audio format from magic bytes and validates it's one of the allowed formats.
|
||||
// Returns the temp file path for logging purposes.
|
||||
func SaveAndValidateAudio(t *testing.T, audioData []byte) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
if len(audioData) == 0 {
|
||||
return "", fmt.Errorf("audio data is empty")
|
||||
}
|
||||
|
||||
// Detect audio format from magic bytes
|
||||
detectedFormat := DetectAudioFormat(audioData)
|
||||
if detectedFormat == "" {
|
||||
return "", fmt.Errorf("unable to detect audio format from data (first 16 bytes: %x)", audioData[:min(16, len(audioData))])
|
||||
}
|
||||
|
||||
// Validate the detected format is in the allowed list
|
||||
if !AllowedAudioFormats[detectedFormat] {
|
||||
allowedList := make([]string, 0, len(AllowedAudioFormats))
|
||||
for format := range AllowedAudioFormats {
|
||||
allowedList = append(allowedList, format)
|
||||
}
|
||||
return "", fmt.Errorf("detected format %q is not in allowed formats: %v", detectedFormat, allowedList)
|
||||
}
|
||||
|
||||
// Create temp file with unique name in bifrost subdirectory
|
||||
tempDir := os.TempDir()
|
||||
bifrostDir := filepath.Join(tempDir, "bifrost")
|
||||
fileName := fmt.Sprintf("bifrost_test_speech_%s.%s", uuid.New().String(), detectedFormat)
|
||||
filePath := filepath.Join(bifrostDir, fileName)
|
||||
|
||||
// Ensure bifrost subdirectory exists
|
||||
if err := os.MkdirAll(bifrostDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
// Write audio data to file
|
||||
if err := os.WriteFile(filePath, audioData, 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write audio file: %w", err)
|
||||
}
|
||||
|
||||
// Register cleanup to delete file regardless of test outcome
|
||||
// t.Cleanup(func() {
|
||||
// if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
// t.Logf("⚠️ Failed to cleanup audio file %s: %v", filePath, err)
|
||||
// } else {
|
||||
// t.Logf("🧹 Cleaned up audio file: %s", filePath)
|
||||
// }
|
||||
// })
|
||||
|
||||
t.Logf("Detected audio format: %s, saved to temp file: %s (%d bytes)", detectedFormat, filePath, len(audioData))
|
||||
|
||||
// Validate the audio file using the detected format
|
||||
if err := ValidateAudioFile(t, filePath, detectedFormat); err != nil {
|
||||
return filePath, err
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func validateAudioFileInternal(filePath string, expectedFormat string) AudioValidationResult {
|
||||
result := AudioValidationResult{
|
||||
Format: expectedFormat,
|
||||
}
|
||||
|
||||
// Read file
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("failed to read file: %v", err))
|
||||
return result
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("failed to stat file: %v", err))
|
||||
return result
|
||||
}
|
||||
result.FileSize = fileInfo.Size()
|
||||
|
||||
return validateAudioBytesInternal(data, expectedFormat)
|
||||
}
|
||||
|
||||
func validateAudioBytesInternal(data []byte, expectedFormat string) AudioValidationResult {
|
||||
result := AudioValidationResult{
|
||||
Format: expectedFormat,
|
||||
FileSize: int64(len(data)),
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
result.Errors = append(result.Errors, "audio data is empty")
|
||||
return result
|
||||
}
|
||||
|
||||
format := strings.ToLower(expectedFormat)
|
||||
|
||||
switch format {
|
||||
case "mp3":
|
||||
result.MagicBytesValid = validateMP3MagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid MP3 magic bytes")
|
||||
}
|
||||
|
||||
decodeErr := validateMP3Decode(data)
|
||||
result.DecodeValid = decodeErr == nil
|
||||
if decodeErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("MP3 decode failed: %v", decodeErr))
|
||||
}
|
||||
|
||||
case "wav":
|
||||
result.MagicBytesValid = validateWAVMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid WAV magic bytes")
|
||||
}
|
||||
|
||||
decodeErr := validateWAVDecode(data)
|
||||
result.DecodeValid = decodeErr == nil
|
||||
if decodeErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("WAV decode failed: %v", decodeErr))
|
||||
}
|
||||
|
||||
case "flac":
|
||||
result.MagicBytesValid = validateFLACMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid FLAC magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for FLAC
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "ogg", "opus":
|
||||
result.MagicBytesValid = validateOGGMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid OGG magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for OGG
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "aac":
|
||||
result.MagicBytesValid = validateAACMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid AAC magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for AAC
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "mp4", "m4a":
|
||||
result.MagicBytesValid = validateMP4MagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid MP4/M4A magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for MP4/M4A containers
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "webm":
|
||||
result.MagicBytesValid = validateWEBMMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid WEBM magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for WEBM
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "mpeg", "mpga":
|
||||
// MPEG/MPGA are essentially MP3 audio
|
||||
result.MagicBytesValid = validateMP3MagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid MPEG magic bytes")
|
||||
}
|
||||
|
||||
decodeErr := validateMP3Decode(data)
|
||||
result.DecodeValid = decodeErr == nil
|
||||
if decodeErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("MPEG decode failed: %v", decodeErr))
|
||||
}
|
||||
|
||||
case "pcm16":
|
||||
// PCM has no magic bytes, just validate it has data
|
||||
result.MagicBytesValid = len(data) > 0
|
||||
result.DecodeValid = len(data) > 0 && len(data)%2 == 0 // PCM16 should have even byte count
|
||||
|
||||
default:
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("unsupported audio format: %s", format))
|
||||
return result
|
||||
}
|
||||
|
||||
result.Valid = result.MagicBytesValid && result.DecodeValid
|
||||
return result
|
||||
}
|
||||
|
||||
// validateMP3MagicBytes checks for valid MP3 file signatures
|
||||
// MP3 files can start with:
|
||||
// - ID3 tag: 0x49 0x44 0x33 ("ID3")
|
||||
// - MPEG frame sync: 0xFF 0xFB, 0xFF 0xFA, 0xFF 0xF3, 0xFF 0xF2, 0xFF 0xE3, 0xFF 0xE2
|
||||
func validateMP3MagicBytes(data []byte) bool {
|
||||
if len(data) < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for ID3 tag
|
||||
if data[0] == 0x49 && data[1] == 0x44 && data[2] == 0x33 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for MPEG audio frame sync
|
||||
if len(data) >= 2 && data[0] == 0xFF {
|
||||
// Valid MPEG audio frame sync bytes
|
||||
// 0xFB = MPEG1 Layer3
|
||||
// 0xFA = MPEG1 Layer3 with CRC
|
||||
// 0xF3 = MPEG2 Layer3
|
||||
// 0xF2 = MPEG2 Layer3 with CRC
|
||||
// 0xE3 = MPEG2.5 Layer3
|
||||
// 0xE2 = MPEG2.5 Layer3 with CRC
|
||||
switch data[1] & 0xF6 {
|
||||
case 0xF2, 0xE2: // Layer 3
|
||||
return true
|
||||
}
|
||||
// Also check the more common patterns
|
||||
switch data[1] {
|
||||
case 0xFB, 0xFA, 0xF3, 0xF2, 0xE3, 0xE2:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateMP3Decode attempts to decode MP3 data to verify it's valid
|
||||
func validateMP3Decode(data []byte) error {
|
||||
reader := bytes.NewReader(data)
|
||||
decoder, err := mp3.NewDecoder(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create MP3 decoder: %w", err)
|
||||
}
|
||||
|
||||
// Try to read a small sample to verify decoding works
|
||||
buf := make([]byte, 4096)
|
||||
n, err := decoder.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
return fmt.Errorf("failed to decode MP3 sample: %w", err)
|
||||
}
|
||||
|
||||
if n == 0 && err != io.EOF {
|
||||
return fmt.Errorf("no audio data decoded from MP3")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateWAVMagicBytes checks for valid WAV file signature
|
||||
// WAV files start with "RIFF" followed by file size, then "WAVE"
|
||||
func validateWAVMagicBytes(data []byte) bool {
|
||||
if len(data) < 12 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check RIFF header
|
||||
if string(data[0:4]) != "RIFF" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check WAVE format
|
||||
if string(data[8:12]) != "WAVE" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateWAVDecode parses WAV header to verify the file structure
|
||||
func validateWAVDecode(data []byte) error {
|
||||
if len(data) < 44 {
|
||||
return fmt.Errorf("WAV file too small: %d bytes (minimum 44)", len(data))
|
||||
}
|
||||
|
||||
// Verify RIFF chunk
|
||||
if string(data[0:4]) != "RIFF" {
|
||||
return fmt.Errorf("missing RIFF header")
|
||||
}
|
||||
|
||||
// Get file size from header (we just validate the header exists, not the size
|
||||
// since some encoders don't set this correctly and streaming may not have final size)
|
||||
_ = binary.LittleEndian.Uint32(data[4:8])
|
||||
|
||||
// Verify WAVE format
|
||||
if string(data[8:12]) != "WAVE" {
|
||||
return fmt.Errorf("missing WAVE format marker")
|
||||
}
|
||||
|
||||
// Find and validate fmt chunk
|
||||
offset := 12
|
||||
foundFmt := false
|
||||
foundData := false
|
||||
|
||||
for offset < len(data)-8 {
|
||||
chunkID := string(data[offset : offset+4])
|
||||
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
|
||||
|
||||
if chunkID == "fmt " {
|
||||
foundFmt = true
|
||||
if offset+8+chunkSize > len(data) {
|
||||
return fmt.Errorf("fmt chunk extends beyond file")
|
||||
}
|
||||
|
||||
// Validate audio format (offset+8 is start of fmt chunk data)
|
||||
if offset+10 <= len(data) {
|
||||
audioFormat := binary.LittleEndian.Uint16(data[offset+8 : offset+10])
|
||||
// 1 = PCM, 3 = IEEE float, 6 = A-law, 7 = mu-law, 0xFFFE = extensible
|
||||
validFormats := map[uint16]bool{1: true, 3: true, 6: true, 7: true, 0xFFFE: true}
|
||||
if !validFormats[audioFormat] {
|
||||
return fmt.Errorf("unsupported audio format in WAV: %d", audioFormat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunkID == "data" {
|
||||
foundData = true
|
||||
}
|
||||
|
||||
offset += 8 + chunkSize
|
||||
// Align to even boundary
|
||||
if chunkSize%2 != 0 {
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFmt {
|
||||
return fmt.Errorf("missing fmt chunk in WAV file")
|
||||
}
|
||||
|
||||
if !foundData {
|
||||
return fmt.Errorf("missing data chunk in WAV file")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFLACMagicBytes checks for valid FLAC file signature
|
||||
// FLAC files start with "fLaC" (0x66 0x4C 0x61 0x43)
|
||||
func validateFLACMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
return string(data[0:4]) == "fLaC"
|
||||
}
|
||||
|
||||
// validateOGGMagicBytes checks for valid OGG file signature
|
||||
// OGG files start with "OggS" (0x4F 0x67 0x67 0x53)
|
||||
func validateOGGMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
return string(data[0:4]) == "OggS"
|
||||
}
|
||||
|
||||
// validateAACMagicBytes checks for valid AAC file signature
|
||||
// AAC ADTS frames start with 0xFF 0xF1 or 0xFF 0xF9
|
||||
// AAC in M4A container starts with "ftyp" at offset 4
|
||||
func validateAACMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for ADTS sync word
|
||||
if data[0] == 0xFF && (data[1]&0xF0) == 0xF0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for M4A/MP4 container (ftyp at offset 4)
|
||||
if len(data) >= 8 && string(data[4:8]) == "ftyp" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateWEBMMagicBytes checks for valid WEBM file signature
|
||||
// WEBM files start with EBML header: 0x1A 0x45 0xDF 0xA3
|
||||
// and contain "webm" doctype somewhere in the first ~40 bytes
|
||||
func validateWEBMMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for EBML header (Matroska/WebM container)
|
||||
if data[0] != 0x1A || data[1] != 0x45 || data[2] != 0xDF || data[3] != 0xA3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Look for "webm" doctype in the first 64 bytes
|
||||
searchLen := 64
|
||||
if len(data) < searchLen {
|
||||
searchLen = len(data)
|
||||
}
|
||||
|
||||
return bytes.Contains(data[:searchLen], []byte("webm"))
|
||||
}
|
||||
|
||||
// validateMP4MagicBytes checks for valid MP4/M4A file signature
|
||||
// MP4/M4A files have "ftyp" at offset 4, followed by brand identifiers
|
||||
func validateMP4MagicBytes(data []byte) bool {
|
||||
if len(data) < 12 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for ftyp box
|
||||
return string(data[4:8]) == "ftyp"
|
||||
}
|
||||
|
||||
// DetectAudioFormat detects the audio format from the buffer header bytes.
|
||||
// Returns the detected format string (mp3, wav, flac, ogg, mp4, m4a, webm) or empty string if unknown.
|
||||
func DetectAudioFormat(data []byte) string {
|
||||
if len(data) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check WAV first (RIFF + WAVE)
|
||||
if validateWAVMagicBytes(data) {
|
||||
return "wav"
|
||||
}
|
||||
|
||||
// Check FLAC (fLaC)
|
||||
if validateFLACMagicBytes(data) {
|
||||
return "flac"
|
||||
}
|
||||
|
||||
// Check OGG/Opus (OggS)
|
||||
if validateOGGMagicBytes(data) {
|
||||
return "ogg"
|
||||
}
|
||||
|
||||
// Check WEBM (EBML header with webm doctype) - check before MP4 as both are containers
|
||||
if validateWEBMMagicBytes(data) {
|
||||
return "webm"
|
||||
}
|
||||
|
||||
// Check MP4/M4A container (ftyp box) - returns m4a for audio containers
|
||||
if validateMP4MagicBytes(data) {
|
||||
return "m4a"
|
||||
}
|
||||
|
||||
// Check MP3 (ID3 or MPEG frame sync)
|
||||
if validateMP3MagicBytes(data) {
|
||||
return "mp3"
|
||||
}
|
||||
|
||||
// Check AAC ADTS (raw AAC stream without container)
|
||||
if len(data) >= 2 && data[0] == 0xFF && (data[1]&0xF0) == 0xF0 {
|
||||
return "aac"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user