mlx-examples/encodec
Awni Hannun 9bb2dd62f3
Encodec (#991)
2024-09-23 11:39:25 -07:00

EnCodec

An example of Meta's EnCodec model in MLX [1]. EnCodec is used to compress and generate audio.

Setup

Install the requirements:

pip install -r requirements.txt

Optionally install FFmpeg and SciPy for loading and saving audio files, respectively.

Install FFmpeg:

# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg

Install SciPy:

pip install scipy

Example

An example using the model:

import mlx.core as mx
from utils import load, load_audio, save_audio

# Load the 48 kHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)

# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=3)

# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
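The example above passes a mask alongside the features and trims the output to `len(audio)`. The sketch below illustrates why both steps matter for batched inputs, using plain Python lists. It assumes that batching right-pads shorter clips with zeros to a common length. The helpers `pad_batch` and `trim_batch` are hypothetical names for illustration only and are not part of `utils.py`; the real preprocessor may behave differently.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    clips: Sequence[Sequence[float]],
) -> Tuple[List[List[float]], List[List[int]]]:
    """Right-pad clips with zeros to a common length and build a 0/1 mask.

    Hypothetical helper: a mask value of 1 marks a real sample,
    0 marks padding.
    """
    max_len = max(len(clip) for clip in clips)
    padded = [list(clip) + [0.0] * (max_len - len(clip)) for clip in clips]
    mask = [[1] * len(clip) + [0] * (max_len - len(clip)) for clip in clips]
    return padded, mask


def trim_batch(
    outputs: Sequence[Sequence[float]], lengths: Sequence[int]
) -> List[List[float]]:
    """Cut each output back to the length of its original clip."""
    return [list(out[:n]) for out, n in zip(outputs, lengths)]


if __name__ == "__main__":
    clips = [[0.1, 0.2, 0.3], [0.4]]
    padded, mask = pad_batch(clips)
    # Stand-in for encode/decode: here the "reconstruction" is just the
    # padded input, so trimming should recover the original clips.
    restored = trim_batch(padded, [len(clip) for clip in clips])
    print(mask)
    print(restored)
```

With a batch, each reconstructed clip would be trimmed to its own original length rather than to a single shared length.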

The 24 kHz, 32 kHz, and 48 kHz MLX-formatted models are available from the Hugging Face MLX Community in several data types.
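To put the `bandwidth` setting from the example in context across these sampling rates, the following sketch estimates how much smaller the encoded stream is than uncompressed audio. It rests on assumptions that the README does not state: that `bandwidth` is expressed in kilobits per second, that the baseline is 16-bit PCM, and that the channel counts in the usage lines (stereo at 48 kHz, mono at 24 kHz) match the models. Substitute `model.sampling_rate` and `model.channels` for real values.

```python
def raw_pcm_kbps(
    sampling_rate: int, channels: int, bits_per_sample: int = 16
) -> float:
    """Bitrate of uncompressed PCM audio in kilobits per second."""
    return sampling_rate * channels * bits_per_sample / 1000


def compression_ratio(
    sampling_rate: int,
    channels: int,
    bandwidth_kbps: float,
    bits_per_sample: int = 16,
) -> float:
    """Ratio of the raw PCM bitrate to the target codec bandwidth.

    Assumes bandwidth_kbps is in kilobits per second.
    """
    return raw_pcm_kbps(sampling_rate, channels, bits_per_sample) / bandwidth_kbps


if __name__ == "__main__":
    # Assumed configuration: 48 kHz, 2 channels, bandwidth of 3.
    print(compression_ratio(48_000, 2, 3))
    # Assumed configuration: 24 kHz, 1 channel, bandwidth of 3.
    print(compression_ratio(24_000, 1, 3))
```

Under these assumptions, halving the bandwidth doubles the ratio, which is the trade-off against reconstruction quality noted in the example.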

Optional

To convert models, use the convert.py script. To see the options, run:

python convert.py -h

[1] Refer to the arXiv paper and code for more details.