MSEB - Massive Sound Embedding Benchmark

A benchmarking framework by Google Research for evaluating sound embedding methods across diverse sound categories and tasks. Apache 2.0 licensed.

Build & Test

pip install -e .            # Install in editable (dev) mode
pip install -e ".[dev]"     # With dev dependencies (pytest, pylint, pyink)
pytest mseb/                # Run all tests
pytest mseb/encoder_test.py # Run a specific test file
pytest -m "not optional"    # Skip tests requiring optional deps (whisper, scann, spacy, tf_hub)

Formatting: Uses pyink (Google style), 80-char line length, 2-space indentation, majority quotes.

Build system: flit (flit_core). Version from mseb/__init__.py.

Project Structure

mseb/
├── encoder.py          # Base MultiModalEncoder, CascadeEncoder, CollectionEncoder
├── types.py            # Core data types: Sound, Text, SoundEmbedding, TextEmbedding, etc.
├── task.py             # Base MSEBTask class
├── evaluator.py        # Base evaluator class
├── runner.py           # DirectRunner (local) and BeamRunner (distributed)
├── dataset.py          # Dataset base class
├── leaderboard.py      # Result aggregation and reporting
├── metrics.py          # Metric computation utilities
├── decoder.py          # Decoder utilities
├── svq.py              # SVQ dataset utilities
├── utils.py            # General utilities
├── encoders/           # ~38 concrete encoder implementations + registry
├── evaluators/         # Task-specific evaluators (retrieval, classification, clustering, etc.)
├── tasks/              # Task definitions organized by type (retrieval, classification, etc.)
├── datasets/           # Dataset implementations
├── scripts/            # CLI entry points (run_task.py, run_rag_task.py, run_task_setup.py, etc.)
├── results/            # Pre-computed benchmark results (JSONL)
└── testdata/           # Test fixtures

Encoder Architecture

All encoders inherit from MultiModalEncoder (mseb/encoder.py). Key design:

  • Lazy init: __init__ stores config only; heavy model loading in _setup(), called once via idempotent setup().
  • Template method: Public setup() and encode() are @final. Subclasses implement _setup(), _encode(), and _check_input_types().
  • Encoding stats: encode() automatically records EncodingStats (input size, output size, FLOPs) on each embedding.
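
The lazy-init/template-method design above can be sketched in isolation. This is a simplified stand-in with hypothetical names (`LazyEncoder`, `_ready`), not the real base class in mseb/encoder.py:

```python
from abc import ABC, abstractmethod


class LazyEncoder(ABC):
  """Sketch of the lazy-init + template-method pattern."""

  def __init__(self, **config):
    # Cheap: store configuration only; no model loading here.
    self._config = config
    self._ready = False

  def setup(self):
    # Idempotent: heavy loading happens at most once.
    if not self._ready:
      self._setup()
      self._ready = True

  def encode(self, batch):
    # Template method: ensure setup, validate, then delegate.
    self.setup()
    self._check_input_types(batch)
    return self._encode(batch)

  @abstractmethod
  def _setup(self): ...

  @abstractmethod
  def _check_input_types(self, batch): ...

  @abstractmethod
  def _encode(self, batch): ...
```

Because setup() is idempotent, callers can invoke encode() repeatedly without paying the model-loading cost more than once.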

Creating an Encoder

Implement three abstract methods:

from typing import Sequence
from mseb import types
from mseb.encoder import MultiModalEncoder

class MyEncoder(MultiModalEncoder):
  def _setup(self) -> None:
    # Load model weights and initialize resources.
    ...

  def _check_input_types(self, batch: Sequence[types.MultiModalObject]) -> None:
    # Validate that every item is the expected type (e.g., types.Sound).
    ...

  def _encode(self, batch: Sequence[types.MultiModalObject]) -> Sequence[types.MultiModalObject]:
    # Transform inputs to embeddings and return them.
    ...

Composition Patterns

  • CascadeEncoder: Chains encoders sequentially (output of one feeds input of next). Example: ASR encoder → Converter → Text embedding encoder.
  • CollectionEncoder: Dispatches to different encoders by input type. Example: Sound encoder for audio queries + Text encoder for document indexing.
  • Converters (encoders/converter.py): Bridge modality gaps between cascade stages (e.g., SoundEmbeddingToTextConverter).
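
Conceptually, the cascade pattern is function composition and the collection pattern is type-based dispatch. A minimal sketch with hypothetical stand-in classes (the real CascadeEncoder/CollectionEncoder in mseb/encoder.py have richer interfaces):

```python
class Cascade:
  """Sketch: chain stages so each stage's output feeds the next."""

  def __init__(self, *stages):
    self.stages = stages

  def encode(self, batch):
    for stage in self.stages:
      batch = stage(batch)
    return batch


class Collection:
  """Sketch: dispatch each item to an encoder chosen by its type."""

  def __init__(self, by_type):
    self.by_type = by_type  # e.g. {Sound: audio_encoder, Text: text_encoder}

  def encode(self, batch):
    return [self.by_type[type(item)](item) for item in batch]
```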

Encoder Registry

encoders/encoder_registry.py provides lookup-by-name for all registered encoders, used by scripts for CLI-based encoder selection.

RAG (Retrieval-Augmented Generation)

RAG is implemented as a retrieval task type with two phases:

Setup Phase (scripts/run_task_setup.py)

  1. RetrievalTask.documents() yields the document corpus
  2. Runner encodes all documents into embeddings
  3. RetrievalEvaluator.build_index() builds a ScaNN (Scalable Approximate Nearest Neighbors) index using dot-product similarity with tree+AH quantization
  4. Index and ID mapping saved to disk
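
The index's contract can be illustrated with brute-force dot-product search as a stand-in for ScaNN (illustrative only; the real evaluator uses the scann library with tree+AH quantization):

```python
def build_index(doc_embeddings, doc_ids):
  # The "index" here is just the embedding matrix plus the ID mapping.
  return list(doc_embeddings), list(doc_ids)


def search(index, query, k=3):
  embeddings, ids = index
  # Score every document by dot product with the query embedding.
  scores = [
      (sum(q * d for q, d in zip(query, doc)), doc_id)
      for doc, doc_id in zip(embeddings, ids)
  ]
  scores.sort(reverse=True)
  return [(doc_id, score) for score, doc_id in scores[:k]]
```

ScaNN replaces the exhaustive scan with an approximate search, but the inputs (embeddings + ID mapping) and outputs (top-k IDs with scores) are the same shape.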

Inference Phase (scripts/run_rag_task.py)

  1. Queries (audio) are encoded via the query encoder
  2. RetrievalEncoder loads the pre-built ScaNN index
  3. For each query embedding, ScaNN returns top-k nearest document IDs with scores
  4. Results formatted as ListPrediction (ranked list of {id, score})
  5. Evaluated with MRR (Mean Reciprocal Rank), Recall@k, Exact Match
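
The retrieval metrics are standard and can be computed directly from ranked ID lists; a minimal sketch (not the project's metrics.py implementation):

```python
def mrr(ranked_ids_per_query, gold_ids):
  """Mean Reciprocal Rank; a query scores 0 if its gold ID is absent."""
  total = 0.0
  for ranked, gold in zip(ranked_ids_per_query, gold_ids):
    if gold in ranked:
      total += 1.0 / (ranked.index(gold) + 1)
  return total / len(gold_ids)


def recall_at_k(ranked_ids_per_query, gold_ids, k):
  """Fraction of queries whose gold ID appears in the top k results."""
  hits = sum(
      gold in ranked[:k]
      for ranked, gold in zip(ranked_ids_per_query, gold_ids)
  )
  return hits / len(gold_ids)
```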

Key RAG Files

  • encoders/retrieval_encoder.py — Encoder that wraps ScaNN search as an encode step
  • evaluators/retrieval_evaluator.py — Builds ScaNN indexes, computes predictions and metrics
  • tasks/retrieval.py — Base RetrievalTask managing index lifecycle
  • tasks/retrievals/ — Concrete retrieval tasks (passage/document, in-lang/cross-lang) over the SVQ dataset

RAG Pipeline Composition

A typical RAG encoder is a CascadeEncoder:

[QueryEncoder (Sound → SoundEmbedding)] → [RetrievalEncoder (SoundEmbedding → TextPrediction)]

A CollectionEncoder then dispatches Sound queries to this cascade and Text documents to a text encoder for index building.

Type System (mseb/types.py)

Core union: MultiModalObject = Sound | Text | SoundEmbedding | TextEmbedding | TextPrediction | ...

  • Sound: waveform array + SoundContextParams (id, sample_rate, length, language, text)
  • Text: text string + TextContextParams (id, title, context)
  • SoundEmbedding: embedding array + timestamps + context
  • TextEmbedding: embedding array + character spans + context
  • TextPrediction: prediction string + context
  • ListPrediction: ranked retrieval results with normalization/merge support

Uses jaxtyping for array shape annotations (e.g., Float[Array, "N D"]).
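
The core types are essentially dataclasses pairing an array with context parameters. A simplified sketch (field names approximate; see mseb/types.py for the real definitions, which use numpy arrays with jaxtyping shape annotations):

```python
import dataclasses
from typing import Optional, Sequence


@dataclasses.dataclass
class SoundContextParams:
  id: str
  sample_rate: int
  length: int
  language: Optional[str] = None
  text: Optional[str] = None


@dataclasses.dataclass
class Sound:
  waveform: Sequence[float]  # stand-in for Float[Array, "N"]
  params: SoundContextParams
```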

Testing

  • Framework: pytest + absltest
  • conftest.py initializes absl flags before test discovery
  • Markers: @pytest.mark.optional, @pytest.mark.whisper, @pytest.mark.scann
  • Tests use mock encoders; test data in mseb/testdata/

Scripts

Script               Purpose
run_task.py          Run a benchmark task end-to-end
run_rag_task.py      Run RAG retrieval tasks
run_task_setup.py    Build indexes/weights for tasks
run_clustering.py    Run clustering evaluation
flatten_results.py   Flatten results for analysis
generate_table.py    Generate result tables

Scripts use absl.flags for configuration (--task, --encoder, --batch_size, etc.).