From 9bb2dd62f350d9f0f1a8003e354431b5172831e5 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 23 Sep 2024 11:39:25 -0700
Subject: [PATCH] Encodec (#991)

* initial encodec

* works

* nits

* use fast group norm

* fix for rnn layer

* fix mlx version

* use custom LSTM kernel

* audio encodec

* fix example, support batched inference

* nits
---
 README.md                      |   1 +
 encodec/README.md              |  83 ++++
 encodec/benchmarks/bench_mx.py |  30 ++
 encodec/benchmarks/bench_pt.py |  34 ++
 encodec/convert.py             | 213 +++++++++++
 encodec/encodec.py             | 671 +++++++++++++++++++++++++++++++++
 encodec/example.py             |  37 ++
 encodec/requirements.txt       |   3 +
 encodec/test.py                |  66 ++++
 encodec/utils.py               | 129 +++++++
 10 files changed, 1267 insertions(+)
 create mode 100644 encodec/README.md
 create mode 100644 encodec/benchmarks/bench_mx.py
 create mode 100644 encodec/benchmarks/bench_pt.py
 create mode 100644 encodec/convert.py
 create mode 100644 encodec/encodec.py
 create mode 100644 encodec/example.py
 create mode 100644 encodec/requirements.txt
 create mode 100644 encodec/test.py
 create mode 100644 encodec/utils.py

diff --git a/README.md b/README.md
index 00e57803..bd180975 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Some more useful examples are listed below.
 ### Audio Models
 
 - Speech recognition with [OpenAI's Whisper](whisper).
+- Audio compression and generation with [Meta's EnCodec](encodec).
 
 ### Multimodal models
 
diff --git a/encodec/README.md b/encodec/README.md
new file mode 100644
index 00000000..3ab2793c
--- /dev/null
+++ b/encodec/README.md
@@ -0,0 +1,83 @@
+# EnCodec
+
+An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
+generate audio.
+
+### Setup
+
+Install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+Optionally install FFmpeg and SciPy for loading and saving audio files,
+respectively.
+
+Install [FFmpeg](https://ffmpeg.org/):
+
+```
+# on macOS using Homebrew (https://brew.sh/)
+brew install ffmpeg
+```
+
+Install SciPy:
+
+```
+pip install scipy
+```
+
+### Example
+
+An example using the model:
+
+```python
+import mlx.core as mx
+from utils import load, load_audio, save_audio
+
+# Load the 48 kHz model and preprocessor.
+model, processor = load("mlx-community/encodec-48khz-float32")
+
+# Load an audio file
+audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
+
+# Preprocess the audio (this can also be a list of arrays for batched
+# processing).
+feats, mask = processor(audio)
+
+# Encode at the given bandwidth. A lower bandwidth results in more
+# compression but lower reconstruction quality.
+@mx.compile
+def encode(feats, mask):
+    return model.encode(feats, mask, bandwidth=3)
+
+# Decode to reconstruct the audio
+@mx.compile
+def decode(codes, scales, mask):
+    return model.decode(codes, scales, mask)
+
+
+codes, scales = encode(feats, mask)
+reconstructed = decode(codes, scales, mask)
+
+# Trim any padding:
+reconstructed = reconstructed[0, : len(audio)]
+
+# Save the audio as a wave file
+save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
+```
+
+The 24 kHz, 32 kHz, and 48 kHz MLX-formatted models are available in the
+[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
+in several data types.
+
+### Optional
+
+To convert models, use the `convert.py` script. To see the options, run:
+
+```bash
+python convert.py -h
+```
+
+[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
+    [code](https://github.com/facebookresearch/encodec) for more details.
diff --git a/encodec/benchmarks/bench_mx.py b/encodec/benchmarks/bench_mx.py
new file mode 100644
index 00000000..2acd4b75
--- /dev/null
+++ b/encodec/benchmarks/bench_mx.py
@@ -0,0 +1,30 @@
+# Copyright © 2024 Apple Inc.
+
+import time
+
+import mlx.core as mx
+from utils import load
+
+model, processor = load("mlx-community/encodec-48khz-float32")
+
+audio = mx.random.uniform(shape=(288000, 2))
+feats, mask = processor(audio)
+mx.eval(model, feats, mask)
+
+
+@mx.compile
+def fun():
+    codes, scales = model.encode(feats, mask, bandwidth=3)
+    reconstructed = model.decode(codes, scales, mask)
+    return reconstructed
+
+
+for _ in range(5):
+    mx.eval(fun())
+
+tic = time.time()
+for _ in range(10):
+    mx.eval(fun())
+toc = time.time()
+ms = 1000 * (toc - tic) / 10
+print(f"Time per it: {ms:.3f} ms")
diff --git a/encodec/benchmarks/bench_pt.py b/encodec/benchmarks/bench_pt.py
new file mode 100644
index 00000000..5d158a32
--- /dev/null
+++ b/encodec/benchmarks/bench_pt.py
@@ -0,0 +1,34 @@
+# Copyright © 2024 Apple Inc.
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, EncodecModel
+
+processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
+audio = np.random.uniform(size=(2, 288000)).astype(np.float32)
+
+pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
+pt_inputs = processor(
+    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
+).to("mps")
+
+
+def fun():
+    pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
+    pt_audio = pt_model.decode(
+        pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
+    )
+    torch.mps.synchronize()
+
+
+for _ in range(5):
+    fun()
+
+tic = time.time()
+for _ in range(10):
+    fun()
+toc = time.time()
+ms = 1000 * (toc - tic) / 10
+print(f"Time per it: {ms:.3f} ms")
diff --git a/encodec/convert.py b/encodec/convert.py
new file mode 100644
index 00000000..13bd31a6
--- /dev/null
+++ b/encodec/convert.py
@@ -0,0 +1,213 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import json
+from pathlib import Path
+from textwrap import dedent
+from types import SimpleNamespace
+from typing import Any, Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+from huggingface_hub import snapshot_download
+
+import encodec
+
+
+def fetch_from_hub(hf_repo: str) -> Path:
+    model_path = Path(
+        snapshot_download(
+            repo_id=hf_repo,
+            allow_patterns=["*.json", "*.safetensors"],
+        )
+    )
+    return model_path
+
+
+def upload_to_hub(path: str, upload_repo: str, hf_path: str):
+    """
+    Uploads the model to the Hugging Face Hub.
+
+    Args:
+        path (str): Local path to the model.
+        upload_repo (str): Name of the HF repo to upload to.
+        hf_path (str): Path to the original Hugging Face model.
+    """
+    import os
+
+    from huggingface_hub import HfApi, ModelCard, logging
+
+    content = dedent(
+        f"""
+        ---
+        language: en
+        license: other
+        library: mlx
+        tags:
+        - mlx
+        ---
+
+        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
+        converted to MLX format from
+        [{hf_path}](https://huggingface.co/{hf_path}).
+
+        This model is intended to be used with the [EnCodec MLX
+        example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
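+
+        A minimal usage sketch with the example's `utils.py` (the exact API is
+        the one defined in the example linked above):
+
+        ```python
+        from utils import load, load_audio, save_audio
+
+        model, processor = load("{upload_repo}")
+        audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
+        feats, mask = processor(audio)
+        codes, scales = model.encode(feats, mask, bandwidth=3)
+        reconstructed = model.decode(codes, scales, mask)
+        ```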
+        """
+    )
+
+    card = ModelCard(content)
+    card.save(os.path.join(path, "README.md"))
+
+    logging.set_verbosity_info()
+
+    api = HfApi()
+    api.create_repo(repo_id=upload_repo, exist_ok=True)
+    api.upload_folder(
+        folder_path=path,
+        repo_id=upload_repo,
+        repo_type="model",
+        multi_commits=True,
+        multi_commits_verbose=True,
+    )
+    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
+
+
+def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+    if isinstance(save_path, str):
+        save_path = Path(save_path)
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    total_size = sum(v.nbytes for v in weights.values())
+    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    mx.save_safetensors(
+        str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
+    )
+
+    for weight_name in weights.keys():
+        index_data["weight_map"][weight_name] = "model.safetensors"
+
+    index_data["weight_map"] = {
+        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+    }
+
+    with open(save_path / "model.safetensors.index.json", "w") as f:
+        json.dump(index_data, f, indent=4)
+
+
+def save_config(
+    config: dict,
+    config_path: Union[str, Path],
+) -> None:
+    """Save the model configuration to the ``config_path``.
+
+    The final configuration will be sorted before saving for better readability.
+
+    Args:
+        config (dict): The model configuration.
+        config_path (Union[str, Path]): Model configuration file path.
+    """
+    # Clean unused keys
+    config.pop("_name_or_path", None)
+
+    # sort the config for better readability
+    config = dict(sorted(config.items()))
+
+    # write the updated config to the config_path
+    with open(config_path, "w") as fid:
+        json.dump(config, fid, indent=4)
+
+
+def convert(
+    upload: bool,
+    model: str,
+    dtype: Optional[str] = None,
+):
+    hf_repo = f"facebook/encodec_{model}"
+    mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
+    path = fetch_from_hub(hf_repo)
+    save_path = Path("mlx_models")
+
+    weights = mx.load(str(Path(path) / "model.safetensors"))
+
+    with open(path / "config.json", "r") as fid:
+        config = SimpleNamespace(**json.load(fid))
+
+    model = encodec.EncodecModel(config)
+
+    new_weights = {}
+    for k, v in weights.items():
+        basename, pname = k.rsplit(".", 1)
+        if pname == "weight_v":
+            g = weights[basename + ".weight_g"]
+            v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
+            k = basename + ".weight"
+        elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
+            continue
+        elif "lstm" in basename:
+            w_or_b, ih_or_hh, ln = pname.split("_")
+            if w_or_b == "weight":
+                new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
+            elif w_or_b == "bias" and ih_or_hh == "ih":
+                continue
+            else:
+                v = v + weights[k.replace("_hh_", "_ih_")]
+                new_pname = "bias"
+            k = basename + "." + ln[1:] + "." + new_pname
+        if "conv.weight" in k:
+            # Possibly a transposed conv which has a different order
+            if "decoder" in k:
+                ln = int(k.split(".")[2])
+                if "conv" in model.decoder.layers[ln] and isinstance(
+                    model.decoder.layers[ln].conv, nn.ConvTranspose1d
+                ):
+                    v = mx.moveaxis(v, 0, 2)
+                else:
+                    v = mx.moveaxis(v, 1, 2)
+            else:
+                v = mx.moveaxis(v, 1, 2)
+
+        new_weights[k] = v
+    weights = new_weights
+
+    model.load_weights(list(weights.items()))
+
+    if dtype is not None:
+        t = getattr(mx, dtype)
+        weights = {k: v.astype(t) for k, v in weights.items()}
+
+    if isinstance(save_path, str):
+        save_path = Path(save_path)
+
+    save_weights(save_path, weights)
+
+    save_config(vars(config), config_path=save_path / "config.json")
+
+    if upload:
+        upload_to_hub(save_path, mlx_repo, hf_repo)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="48khz",
+        help="The EnCodec model to convert.",
+        choices=["24khz", "32khz", "48khz"],
+    )
+    parser.add_argument(
+        "--upload",
+        action="store_true",
+        help="Upload the weights to Hugging Face.",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        help="Data type to convert the model to.",
+        default="float32",
+        choices=["float32", "bfloat16", "float16"],
+    )
+    args = parser.parse_args()
+    convert(upload=args.upload, model=args.model, dtype=args.dtype)
diff --git a/encodec/encodec.py b/encodec/encodec.py
new file mode 100644
index 00000000..3ef47369
--- /dev/null
+++ b/encodec/encodec.py
@@ -0,0 +1,671 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+_lstm_kernel = mx.fast.metal_kernel(
+    name="lstm",
+    input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
+    output_names=["hidden_state", "cell_state"],
+    header="""
+    template <typename T>
+    T sigmoid(T x) {
+        auto y = 1 / (1 + metal::exp(-metal::abs(x)));
+        return (x < 0) ? 1 - y : y;
+    }
+    """,
+    source="""
+        uint b = thread_position_in_grid.x;
+        uint d = hidden_size * 4;
+
+        uint elem = b * d + thread_position_in_grid.y;
+        uint index = elem;
+        uint x_index = b * num_time_steps * d + time_step * d + index;
+
+        auto i = sigmoid(h_in[index] + x[x_index]);
+        index += hidden_size;
+        x_index += hidden_size;
+        auto f = sigmoid(h_in[index] + x[x_index]);
+        index += hidden_size;
+        x_index += hidden_size;
+        auto g = metal::precise::tanh(h_in[index] + x[x_index]);
+        index += hidden_size;
+        x_index += hidden_size;
+        auto o = sigmoid(h_in[index] + x[x_index]);
+
+        cell_state[elem] = f * cell[elem] + i * g;
+        hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
+    """,
+)
+
+
+def lstm_custom(x, h_in, cell, time_step):
+    assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
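+    # The kernel above computes one LSTM time step: each thread gathers the
+    # four gate pre-activations for one hidden unit, applies the gate
+    # nonlinearities, and writes one element of the new hidden and cell
+    # states. The recurrence over time steps is driven by the caller (see
+    # LSTM.__call__ below).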
+ out_shape = cell.shape + return _lstm_kernel( + inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]], + output_shapes=[out_shape, out_shape], + output_dtypes=[h_in.dtype, h_in.dtype], + grid=(x.shape[0], h_in.size // 4, 1), + threadgroup=(256, 1, 1), + ) + + +class LSTM(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ): + super().__init__() + + self.hidden_size = hidden_size + self.Wx = mx.zeros((4 * hidden_size, input_size)) + self.Wh = mx.zeros((4 * hidden_size, hidden_size)) + self.bias = mx.zeros((4 * hidden_size,)) if bias else None + + def __call__(self, x, hidden=None, cell=None): + if self.bias is not None: + x = mx.addmm(self.bias, x, self.Wx.T) + else: + x = x @ self.Wx.T + + all_hidden = [] + + B = x.shape[0] + cell = cell or mx.zeros((B, self.hidden_size), x.dtype) + for t in range(x.shape[-2]): + if hidden is None: + hidden = mx.zeros((B, self.hidden_size * 4), x.dtype) + else: + hidden = hidden @ self.Wh.T + hidden, cell = lstm_custom(x, hidden, cell, t) + all_hidden.append(hidden) + + return mx.stack(all_hidden, axis=-2) + + +class EncodecConv1d(nn.Module): + """Conv1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + ): + super().__init__() + self.causal = config.use_causal_conv + self.pad_mode = config.pad_mode + self.norm_type = config.norm_type + + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, dilation=dilation + ) + if self.norm_type == "time_group_norm": + self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True) + + self.stride = stride + + # Effective kernel size with dilations. + self.kernel_size = (kernel_size - 1) * dilation + 1 + + self.padding_total = kernel_size - stride + + def _get_extra_padding_for_conv1d( + self, + hidden_states: mx.array, + ) -> mx.array: + length = hidden_states.shape[1] + n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = int(math.ceil(n_frames)) - 1 + ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total + return ideal_length - length + + def _pad1d( + self, + hidden_states: mx.array, + paddings: Tuple[int, int], + mode: str = "zero", + value: float = 0.0, + ): + if mode != "reflect": + return mx.pad( + hidden_states, paddings, mode="constant", constant_values=value + ) + + length = hidden_states.shape[1] + prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1] + suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1] + return mx.concatenate([prefix, hidden_states, suffix], axis=1) + + def __call__(self, hidden_states): + extra_padding = self._get_extra_padding_for_conv1d(hidden_states) + + if self.causal: + # Left padding for causal + hidden_states = self._pad1d( + hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode + ) + else: + # Asymmetric padding required for odd strides + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + hidden_states = self._pad1d( + hidden_states, + (padding_left, padding_right + extra_padding), + mode=self.pad_mode, + ) + + hidden_states = self.conv(hidden_states) + + if self.norm_type == "time_group_norm": + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class EncodecConvTranspose1d(nn.Module): + """ConvTranspose1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + 
config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + ): + super().__init__() + self.causal = config.use_causal_conv + self.trim_right_ratio = config.trim_right_ratio + self.norm_type = config.norm_type + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride) + if config.norm_type == "time_group_norm": + self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True) + self.padding_total = kernel_size - stride + + def __call__(self, hidden_states): + hidden_states = self.conv(hidden_states) + + if self.norm_type == "time_group_norm": + hidden_states = self.norm(hidden_states) + + if self.causal: + padding_right = math.ceil(self.padding_total * self.trim_right_ratio) + else: + padding_right = self.padding_total // 2 + + padding_left = self.padding_total - padding_right + + end = hidden_states.shape[1] - padding_right + hidden_states = hidden_states[:, padding_left:end, :] + return hidden_states + + +class EncodecLSTM(nn.Module): + def __init__(self, config, dimension): + super().__init__() + self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)] + + def __call__(self, hidden_states): + h = hidden_states + for lstm in self.lstm: + h = lstm(h) + return h + hidden_states + + +class EncodecResnetBlock(nn.Module): + """ + Residual block from SEANet model as used by EnCodec. + """ + + def __init__(self, config, dim: int, dilations: List[int]): + super().__init__() + kernel_sizes = (config.residual_kernel_size, 1) + if len(kernel_sizes) != len(dilations): + raise ValueError("Number of kernel sizes should match number of dilations") + + hidden = dim // config.compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [nn.ELU()] + block += [ + EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation) + ] + self.block = block + + if getattr(config, "use_conv_shortcut", True): + self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1) + else: + self.shortcut = nn.Identity() + + def __call__(self, hidden_states): + residual = hidden_states + for layer in self.block: + hidden_states = layer(hidden_states) + + return self.shortcut(residual) + hidden_states + + +class EncodecEncoder(nn.Module): + """SEANet encoder as used by EnCodec.""" + + def __init__(self, config): + super().__init__() + model = [ + EncodecConv1d( + config, config.audio_channels, config.num_filters, config.kernel_size + ) + ] + scaling = 1 + + for ratio in reversed(config.upsampling_ratios): + current_scale = scaling * config.num_filters + for j in range(config.num_residual_layers): + model += [ + EncodecResnetBlock( + config, current_scale, [config.dilation_growth_rate**j, 1] + ) + ] + model += [nn.ELU()] + model += [ + EncodecConv1d( + config, + current_scale, + current_scale * 2, + kernel_size=ratio * 2, + stride=ratio, + ) + ] + scaling *= 2 + + model += [EncodecLSTM(config, scaling * config.num_filters)] + model += [nn.ELU()] + model += [ + EncodecConv1d( + config, + scaling * config.num_filters, + config.hidden_size, + config.last_kernel_size, + ) + ] + + self.layers = model + + def __call__(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class EncodecDecoder(nn.Module): + """SEANet decoder as used by EnCodec.""" + + def __init__(self, config): + super().__init__() + scaling = int(2 ** 
len(config.upsampling_ratios)) + model = [ + EncodecConv1d( + config, + config.hidden_size, + scaling * config.num_filters, + config.kernel_size, + ) + ] + + model += [EncodecLSTM(config, scaling * config.num_filters)] + + for ratio in config.upsampling_ratios: + current_scale = scaling * config.num_filters + model += [nn.ELU()] + model += [ + EncodecConvTranspose1d( + config, + current_scale, + current_scale // 2, + kernel_size=ratio * 2, + stride=ratio, + ) + ] + for j in range(config.num_residual_layers): + model += [ + EncodecResnetBlock( + config, current_scale // 2, (config.dilation_growth_rate**j, 1) + ) + ] + scaling //= 2 + + model += [nn.ELU()] + model += [ + EncodecConv1d( + config, + config.num_filters, + config.audio_channels, + config.last_kernel_size, + ) + ] + self.layers = model + + def __call__(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class EncodecEuclideanCodebook(nn.Module): + """Codebook with Euclidean distance.""" + + def __init__(self, config): + super().__init__() + self.embed = mx.zeros((config.codebook_size, config.codebook_dim)) + + def quantize(self, hidden_states): + embed = self.embed.T + scaled_states = hidden_states.square().sum(axis=1, keepdims=True) + dist = -( + scaled_states + - 2 * hidden_states @ embed + + embed.square().sum(axis=0, keepdims=True) + ) + embed_ind = dist.argmax(axis=-1) + return embed_ind + + def encode(self, hidden_states): + shape = hidden_states.shape + hidden_states = hidden_states.reshape((-1, shape[-1])) + embed_ind = self.quantize(hidden_states) + embed_ind = embed_ind.reshape(*shape[:-1]) + return embed_ind + + def decode(self, embed_ind): + return self.embed[embed_ind] + + +class EncodecVectorQuantization(nn.Module): + """ + Vector quantization implementation. Currently supports only euclidean distance. + """ + + def __init__(self, config): + super().__init__() + self.codebook = EncodecEuclideanCodebook(config) + + def encode(self, hidden_states): + return self.codebook.encode(hidden_states) + + def decode(self, embed_ind): + return self.codebook.decode(embed_ind) + + +class EncodecResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer.""" + + def __init__(self, config): + super().__init__() + self.codebook_size = config.codebook_size + + hop_length = np.prod(config.upsampling_ratios) + self.frame_rate = math.ceil(config.sampling_rate / hop_length) + self.num_quantizers = int( + 1000 * config.target_bandwidths[-1] // (self.frame_rate * 10) + ) + self.layers = [ + EncodecVectorQuantization(config) for _ in range(self.num_quantizers) + ] + + def get_num_quantizers_for_bandwidth( + self, bandwidth: Optional[float] = None + ) -> int: + """Return num_quantizers based on specified target bandwidth.""" + bw_per_q = math.log2(self.codebook_size) * self.frame_rate + num_quantizers = self.num_quantizers + if bandwidth is not None and bandwidth > 0.0: + num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q))) + return num_quantizers + + def encode( + self, embeddings: mx.array, bandwidth: Optional[float] = None + ) -> mx.array: + """ + Encode a given input array with the specified frame rate at the given + bandwidth. The RVQ encode method sets the appropriate number of + quantizers to use and returns indices for each quantizer. 
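+        The returned indices are stacked along axis 1, one entry per
+        quantizer, giving an array of shape
+        ``(batch_size, num_quantizers, num_frames)``.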
+        """
+        num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
+        residual = embeddings
+        all_indices = []
+        for layer in self.layers[:num_quantizers]:
+            indices = layer.encode(residual)
+            quantized = layer.decode(indices)
+            residual = residual - quantized
+            all_indices.append(indices)
+        out_indices = mx.stack(all_indices, axis=1)
+        return out_indices
+
+    def decode(self, codes: mx.array) -> mx.array:
+        """Decode the given codes to the quantized representation."""
+        quantized_out = None
+        for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
+            layer = self.layers[i]
+            quantized = layer.decode(indices.squeeze(1))
+            if quantized_out is None:
+                quantized_out = quantized
+            else:
+                quantized_out = quantized + quantized_out
+        return quantized_out
+
+
+class EncodecModel(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.encoder = EncodecEncoder(config)
+        self.decoder = EncodecDecoder(config)
+        self.quantizer = EncodecResidualVectorQuantizer(config)
+
+    def _encode_frame(
+        self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
+    ) -> Tuple[mx.array, Optional[mx.array]]:
+        """
+        Encodes the given input using the underlying VQVAE.
+        """
+        length = input_values.shape[1]
+        duration = length / self.config.sampling_rate
+
+        if (
+            self.config.chunk_length_s is not None
+            and duration > 1e-5 + self.config.chunk_length_s
+        ):
+            raise RuntimeError(
+                f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
+            )
+
+        scale = None
+        if self.config.normalize:
+            # Zero out any padding before computing the normalization scale
+            input_values = input_values * padding_mask[..., None]
+            mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
+            scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
+            input_values = input_values / scale
+
+        embeddings = self.encoder(input_values)
+        codes = self.quantizer.encode(embeddings, bandwidth)
+        return codes, scale
+
+    def encode(
+        self,
+        input_values: mx.array,
+        padding_mask: Optional[mx.array] = None,
+        bandwidth: Optional[float] = None,
+    ) -> Tuple[mx.array, Optional[mx.array]]:
+        """
+        Encodes the input audio waveform into discrete codes.
+
+        Args:
+            input_values (mx.array): The input audio waveform with shape
+                ``(batch_size, sequence_length, channels)``.
+            padding_mask (mx.array): Padding mask used to pad the ``input_values``.
+            bandwidth (float, optional): The target bandwidth in kbps, e.g. a
+                6 kbps target is passed as ``bandwidth == 6.0``. Must be one
+                of ``config.target_bandwidths``. If ``None``, the smallest
+                available bandwidth is used.
+
+        Returns:
+            A tuple ``(codes, scales)``. ``codes`` contains the discrete
+            codes for the input audio waveform with shape ``(num_chunks,
+            batch_size, num_codebooks, frames)``, and ``scales`` is a list
+            with one rescaling factor per chunk (each entry is ``None``
+            unless ``config.normalize == True``).
+        """
+
+        if bandwidth is None:
+            bandwidth = self.config.target_bandwidths[0]
+        if bandwidth not in self.config.target_bandwidths:
+            raise ValueError(
+                f"This model doesn't support the bandwidth {bandwidth}. "
+                f"Select one of {self.config.target_bandwidths}."
+            )
+
+        _, input_length, channels = input_values.shape
+
+        if channels < 1 or channels > 2:
+            raise ValueError(
+                f"Number of audio channels must be 1 or 2, but got {channels}"
+            )
+
+        chunk_length = self.chunk_length
+        if chunk_length is None:
+            chunk_length = input_length
+            stride = input_length
+        else:
+            stride = self.chunk_stride
+
+        if padding_mask is None:
+            padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
+        encoded_frames = []
+        scales = []
+
+        step = chunk_length - stride
+        if (input_length % stride) != step:
+            raise ValueError(
+                "The input length is not properly padded for batched chunked "
+                "encoding. Make sure to pad the input correctly."
+            )
+
+        for offset in range(0, input_length - step, stride):
+            mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
+            frame = input_values[:, offset : offset + chunk_length]
+            encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
+            encoded_frames.append(encoded_frame)
+            scales.append(scale)
+
+        encoded_frames = mx.stack(encoded_frames)
+
+        return (encoded_frames, scales)
+
+    @staticmethod
+    def _linear_overlap_add(frames: List[mx.array], stride: int):
+        if len(frames) == 0:
+            raise ValueError("`frames` cannot be an empty list.")
+
+        dtype = frames[0].dtype
+        N, frame_length, C = frames[0].shape
+        total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
+
+        time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+        weight = 0.5 - (time_vec - 0.5).abs()
+
+        weight = weight[:, None]
+        sum_weight = mx.zeros((total_size, 1), dtype=dtype)
+        out = mx.zeros((N, total_size, C), dtype=dtype)
+        offset = 0
+
+        for frame in frames:
+            frame_length = frame.shape[1]
+            out[:, offset : offset + frame_length] += weight[:frame_length] * frame
+            sum_weight[offset : offset + frame_length] += weight[:frame_length]
+            offset += stride
+
+        return out / sum_weight
+
+    def _decode_frame(
+        self, codes: mx.array, scale: Optional[mx.array] = None
+    ) -> mx.array:
+        embeddings = self.quantizer.decode(codes)
+        outputs = self.decoder(embeddings)
+        if scale is not None:
+            outputs = outputs * scale
+        return outputs
+
+    @property
+    def channels(self):
+        return self.config.audio_channels
+
+    @property
+    def sampling_rate(self):
+        return self.config.sampling_rate
+
+    @property
+    def chunk_length(self):
+        if self.config.chunk_length_s is None:
+            return None
+        else:
+            return int(self.config.chunk_length_s * self.config.sampling_rate)
+
+    @property
+    def chunk_stride(self):
+        if self.config.chunk_length_s is None or self.config.overlap is None:
+            return None
+        else:
+            return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
+
+    def decode(
+        self,
+        audio_codes: mx.array,
+        audio_scales: Union[mx.array, List[mx.array]],
+        padding_mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        """
+        Decodes the given frames into an output audio waveform.
+
+        Note that the output might be a bit bigger than the input. In that
+        case, any extra steps at the end should be trimmed.
+
+        Args:
+            audio_codes (mx.array): Discrete codes of shape ``(num_chunks,
+                batch_size, num_codebooks, frames)``.
+            audio_scales (List[mx.array]): Scaling factor for each chunk.
+            padding_mask (mx.array): Padding mask.
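+
+        Returns:
+            The decoded audio of shape ``(batch_size, sequence_length,
+            channels)``, truncated to the length of ``padding_mask`` when
+            one is given.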
+        """
+        chunk_length = self.chunk_length
+        if chunk_length is None:
+            if len(audio_codes) != 1:
+                raise ValueError(f"Expected one frame, got {len(audio_codes)}")
+            audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
+        else:
+            decoded_frames = []
+
+            for frame, scale in zip(audio_codes, audio_scales):
+                frames = self._decode_frame(frame, scale)
+                decoded_frames.append(frames)
+
+            audio_values = self._linear_overlap_add(
+                decoded_frames, self.chunk_stride or 1
+            )
+
+        # truncate based on padding mask
+        if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
+            audio_values = audio_values[:, : padding_mask.shape[1]]
+        return audio_values
diff --git a/encodec/example.py b/encodec/example.py
new file mode 100644
index 00000000..97b311a1
--- /dev/null
+++ b/encodec/example.py
@@ -0,0 +1,37 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+from utils import load, load_audio, save_audio
+
+# Load the 48 kHz model and preprocessor.
+model, processor = load("mlx-community/encodec-48khz-float32")
+
+# Load an audio file
+audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
+
+# Preprocess the audio (this can also be a list of arrays for batched
+# processing).
+feats, mask = processor(audio)
+
+
+# Encode at the given bandwidth. A lower bandwidth results in more
+# compression but lower reconstruction quality.
+@mx.compile
+def encode(feats, mask):
+    return model.encode(feats, mask, bandwidth=3)
+
+
+# Decode to reconstruct the audio
+@mx.compile
+def decode(codes, scales, mask):
+    return model.decode(codes, scales, mask)
+
+
+codes, scales = encode(feats, mask)
+reconstructed = decode(codes, scales, mask)
+
+# Trim any padding:
+reconstructed = reconstructed[0, : len(audio)]
+
+# Save the audio as a wave file
+save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
diff --git a/encodec/requirements.txt b/encodec/requirements.txt
new file mode 100644
index 00000000..de5cc646
--- /dev/null
+++ b/encodec/requirements.txt
@@ -0,0 +1,3 @@
+mlx>=0.18
+numpy
+huggingface_hub
diff --git a/encodec/test.py b/encodec/test.py
new file mode 100644
index 00000000..ffc23505
--- /dev/null
+++ b/encodec/test.py
@@ -0,0 +1,66 @@
+# Copyright © 2024 Apple Inc.
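+#
+# Consistency tests for the MLX port: the MLX preprocessor and model should
+# match the reference Hugging Face implementation on identical inputs.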
+
+import mlx.core as mx
+import numpy as np
+import torch
+from transformers import AutoProcessor, EncodecModel
+from utils import load, preprocess_audio
+
+
+def compare_processors():
+    np.random.seed(0)
+    audio_length = 95500
+    audio = np.random.uniform(size=(2, audio_length)).astype(np.float32)
+
+    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
+
+    pt_inputs = processor(
+        raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
+    )
+    mx_inputs = preprocess_audio(
+        mx.array(audio).T,
+        processor.sampling_rate,
+        processor.chunk_length,
+        processor.chunk_stride,
+    )
+
+    assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1))
+    assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1])
+
+
+def compare_models():
+    pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
+    mx_model, _ = load("mlx-community/encodec-48khz-float32")
+
+    np.random.seed(0)
+    audio_length = 190560
+    audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32)
+    mask = np.ones((1, audio_length), dtype=np.int32)
+    pt_encoded = pt_model.encode(
+        torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None]
+    )
+    mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask))
+    pt_codes = pt_encoded.audio_codes.numpy()
+    mx_codes = mx_encoded[0]
+    assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"
+
+    for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
+        if mx_scale is not None:
+            pt_scale = pt_scale.numpy()
+            assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)
+
+    pt_audio = pt_model.decode(
+        pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
+    )
+    pt_audio = pt_audio[0].squeeze().T.detach().numpy()
+    mx_audio = mx_model.decode(*mx_encoded, mx.array(mask))
+    mx_audio = mx_audio.squeeze()
+    assert np.allclose(
+        pt_audio, mx_audio, atol=1e-4, rtol=1e-4
+    ), "Decoding audio mismatch"
+
+
+if __name__ == "__main__":
+    compare_processors()
+    compare_models()
diff --git a/encodec/utils.py b/encodec/utils.py
new file mode 100644
index 00000000..18b3f063
--- /dev/null
+++ b/encodec/utils.py
@@ -0,0 +1,129 @@
+# Copyright © 2024 Apple Inc.
+
+import functools
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from typing import List, Optional, Union
+
+import mlx.core as mx
+import numpy as np
+from huggingface_hub import snapshot_download
+
+import encodec
+
+
+def save_audio(file: str, audio: mx.array, sampling_rate: int):
+    """
+    Save audio to a wave (.wav) file.
+    """
+    from scipy.io.wavfile import write
+
+    audio = (audio * 32767).astype(mx.int16)
+    write(file, sampling_rate, np.array(audio))
+
+
+def load_audio(file: str, sampling_rate: int, channels: int):
+    """
+    Read audio into an mx.array, resampling if necessary.
+
+    Args:
+        file (str): The audio file to open.
+        sampling_rate (int): The sample rate to resample the audio at if needed.
+        channels (int): The number of audio channels.
+
+    Returns:
+        An mx.array containing the audio waveform in float32.
+    """
+    from subprocess import CalledProcessError, run
+
+    # This launches a subprocess to decode audio while down-mixing
+    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
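+    # The flags request raw 16-bit little-endian PCM ("-f s16le",
+    # "-acodec pcm_s16le") written to stdout ("-"), down-mixed to
+    # `channels` ("-ac") and resampled to `sampling_rate` ("-ar").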
+    # fmt: off
+    cmd = [
+        "ffmpeg",
+        "-nostdin",
+        "-threads", "0",
+        "-i", file,
+        "-f", "s16le",
+        "-ac", str(channels),
+        "-acodec", "pcm_s16le",
+        "-ar", str(sampling_rate),
+        "-"
+    ]
+    # fmt: on
+    try:
+        out = run(cmd, capture_output=True, check=True).stdout
+    except CalledProcessError as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+    out = mx.array(np.frombuffer(out, np.int16))
+    return out.reshape(-1, channels).astype(mx.float32) / 32767.0
+
+
+def preprocess_audio(
+    raw_audio: Union[mx.array, List[mx.array]],
+    sampling_rate: int = 24000,
+    chunk_length: Optional[int] = None,
+    chunk_stride: Optional[int] = None,
+):
+    r"""
+    Prepare inputs for the EnCodec model.
+
+    Args:
+        raw_audio (mx.array or List[mx.array]): The sequence or batch of
+            sequences to be processed.
+        sampling_rate (int): The sampling rate at which the audio waveform
+            should be digitized.
+        chunk_length (int, optional): The model's chunk length.
+        chunk_stride (int, optional): The model's chunk stride.
+    """
+    if not isinstance(raw_audio, list):
+        raw_audio = [raw_audio]
+
+    raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
+
+    max_length = max(array.shape[0] for array in raw_audio)
+    if chunk_length is not None:
+        max_length += chunk_length - (max_length % chunk_stride)
+
+    inputs = []
+    masks = []
+    for x in raw_audio:
+        length = x.shape[0]
+        mask = mx.ones((length,), dtype=mx.bool_)
+        difference = max_length - length
+        if difference > 0:
+            mask = mx.pad(mask, (0, difference))
+            x = mx.pad(x, ((0, difference), (0, 0)))
+        inputs.append(x)
+        masks.append(mask)
+    return mx.stack(inputs), mx.stack(masks)
+
+
+def load(path_or_repo):
+    """
+    Load the model and audio preprocessor.
+    """
+    path = Path(path_or_repo)
+    if not path.exists():
+        path = Path(
+            snapshot_download(
+                repo_id=path_or_repo,
+                allow_patterns=["*.json", "*.safetensors", "*.model"],
+            )
+        )
+
+    with open(path / "config.json", "r") as f:
+        config = SimpleNamespace(**json.load(f))
+
+    model = encodec.EncodecModel(config)
+    model.load_weights(str(path / "model.safetensors"))
+    processor = functools.partial(
+        preprocess_audio,
+        sampling_rate=config.sampling_rate,
+        chunk_length=model.chunk_length,
+        chunk_stride=model.chunk_stride,
+    )
+    mx.eval(model)
+    return model, processor