initial encodec

This commit is contained in:
Awni Hannun
2024-09-13 14:22:39 -07:00
parent f530f56df2
commit 2d82579bc8
4 changed files with 935 additions and 0 deletions

encodec/convert.py (new file, 201 lines)

@@ -0,0 +1,201 @@
# Copyright © 2024 Apple Inc.
import json
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, Dict, Optional, Union
import fire
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
import encodec
def fetch_from_hub(hf_repo: str) -> Path:
model_path = Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.json", "*.safetensors"],
)
)
return model_path
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
content = dedent(
f"""
---
language: en
license: other
library: mlx
tags:
- mlx
---
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
converted to MLX format from
[{hf_path}](https://huggingface.co/{hf_path}).
This model is intended to be used with the [Encodec MLX
Example](TODO).
"""
)
card = ModelCard(content)
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
if isinstance(save_path, str):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
total_size = sum(v.nbytes for v in weights.values())
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
mx.save_safetensors(
str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
)
for weight_name in weights.keys():
index_data["weight_map"][weight_name] = "model.safetensors"
index_data["weight_map"] = {
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
}
with open(save_path / "model.safetensors.index.json", "w") as f:
json.dump(index_data, f, indent=4)
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Clean unused keys
config.pop("_name_or_path", None)
# sort the config for better readability
config = dict(sorted(config.items()))
# write the updated config to the config_path (if provided)
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)
def convert(
    upload: bool = False,
    model: str = "24khz",
    dtype: Optional[str] = None,
):
    """
    Convert EnCodec models to MLX.
    Args:
        upload (bool): Whether to upload the converted model to Hugging Face.
        model (str): One of "24khz", "32khz", or "48khz".
        dtype (str, optional): One of "float32", "float16", or "bfloat16".
    """
hf_repo = f"facebook/encodec_{model}"
mlx_repo = f"mlx-community/encodec-{model}-mlx"
path = fetch_from_hub(hf_repo)
save_path = Path("mlx_models")
weights = mx.load(str(Path(path) / "model.safetensors"))
with open(path / "config.json", "r") as fid:
config = SimpleNamespace(**json.load(fid))
model = encodec.EncodecModel(config)
new_weights = {}
for k, v in weights.items():
basename, pname = k.rsplit(".", 1)
if pname == "weight_v":
g = weights[basename + ".weight_g"]
v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
k = basename + ".weight"
elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
continue
elif "lstm" in basename:
w_or_b, ih_or_hh, ln = pname.split("_")
if w_or_b == "weight":
new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
v = v.T
elif w_or_b == "bias" and ih_or_hh == "ih":
continue
else:
v = v + weights[k.replace("_hh_", "_ih_")]
new_pname = "bias"
k = basename + "." + ln[1:] + "." + new_pname
if "conv.weight" in k:
# Possibly a transposed conv which has a different order
if "decoder" in k:
ln = int(k.split(".")[2])
if "conv" in model.decoder.layers[ln] and isinstance(
model.decoder.layers[ln].conv, nn.ConvTranspose1d
):
v = mx.moveaxis(v, 0, 2)
else:
v = mx.moveaxis(v, 1, 2)
else:
v = mx.moveaxis(v, 1, 2)
new_weights[k] = v
weights = new_weights
model.load_weights(list(weights.items()))
if dtype is not None:
t = getattr(mx, dtype)
weights = {k: v.astype(t) for k, v in weights.items()}
if isinstance(save_path, str):
save_path = Path(save_path)
save_weights(save_path, weights)
save_config(vars(config), config_path=save_path / "config.json")
if upload:
        upload_to_hub(save_path, mlx_repo, hf_repo)
if __name__ == "__main__":
fire.Fire(convert)
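# Example invocations, assuming the dependencies above are installed (`fire`
# maps the keyword arguments of `convert` to command line flags):
#
#   python convert.py --model 24khz
#   python convert.py --model 48khz --dtype bfloat16
#   python convert.py --model 32khz --upload True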

encodec/encodec.py (new file, 624 lines)

@@ -0,0 +1,624 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
class EncodecConv1d(nn.Module):
"""Conv1d with asymmetric or causal padding and normalization."""
def __init__(
self,
config,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
):
super().__init__()
self.causal = config.use_causal_conv
self.pad_mode = config.pad_mode
self.norm_type = config.norm_type
self.conv = nn.Conv1d(
in_channels, out_channels, kernel_size, stride, dilation=dilation
)
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
self.stride = stride
# Effective kernel size with dilations.
self.kernel_size = (kernel_size - 1) * dilation + 1
self.padding_total = kernel_size - stride
    def _get_extra_padding_for_conv1d(
        self,
        hidden_states: mx.array,
    ) -> int:
length = hidden_states.shape[1]
n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
n_frames = int(math.ceil(n_frames)) - 1
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
return ideal_length - length
def _pad1d(
self,
hidden_states: mx.array,
paddings: Tuple[int, int],
mode: str = "zero",
value: float = 0.0,
):
if mode != "reflect":
return mx.pad(
hidden_states, paddings, mode="constant", constant_values=value
)
length = hidden_states.shape[1]
prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
return mx.concatenate([prefix, hidden_states, suffix], axis=1)
def __call__(self, hidden_states):
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
if self.causal:
# Left padding for causal
hidden_states = self._pad1d(
hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
)
else:
# Asymmetric padding required for odd strides
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
hidden_states = self._pad1d(
hidden_states,
(padding_left, padding_right + extra_padding),
mode=self.pad_mode,
)
hidden_states = self.conv(hidden_states)
if self.norm_type == "time_group_norm":
hidden_states = self.norm(hidden_states)
return hidden_states
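# Numeric sketch of the padding logic above (illustrative only): with
# kernel_size=7 and stride=2, padding_total = 5. For an input of length 101,
# _get_extra_padding_for_conv1d returns 1, so the causal branch pads to
# 5 + 101 + 1 = 107 samples and the conv emits (107 - 7) // 2 + 1 = 51
# frames, i.e. ceil(101 / 2): no trailing samples are ever dropped.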
class EncodecConvTranspose1d(nn.Module):
"""ConvTranspose1d with asymmetric or causal padding and normalization."""
def __init__(
self,
config,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
):
super().__init__()
self.causal = config.use_causal_conv
self.trim_right_ratio = config.trim_right_ratio
self.norm_type = config.norm_type
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
if config.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
self.padding_total = kernel_size - stride
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
if self.norm_type == "time_group_norm":
hidden_states = self.norm(hidden_states)
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
# removed at the very end, when keeping only the right length for the output,
# as removing it here would require also passing the length at the matching layer
# in the encoder.
if self.causal:
# Trim the padding on the right according to the specified ratio
# if trim_right_ratio = 1.0, trim everything from right
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
else:
# Asymmetric padding required for odd strides
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
# unpad
end = hidden_states.shape[1] - padding_right
hidden_states = hidden_states[:, padding_left:end, :]
return hidden_states
class EncodecLSTM(nn.Module):
def __init__(self, config, dimension):
super().__init__()
self.lstm = [
nn.LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)
]
def __call__(self, hidden_states):
h = hidden_states
for lstm in self.lstm:
h = lstm(h)[0]
return h + hidden_states
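# Note on parameters: mlx.nn.LSTM stores `Wx`, `Wh` and a single `bias`,
# whereas the PyTorch checkpoints use `weight_ih_l{k}`, `weight_hh_l{k}`,
# `bias_ih_l{k}` and `bias_hh_l{k}`. convert.py therefore transposes the two
# weight matrices into Wx / Wh and sums the two biases into `bias`.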
class EncodecResnetBlock(nn.Module):
"""
Residual block from SEANet model as used by EnCodec.
"""
def __init__(self, config, dim: int, dilations: List[int]):
super().__init__()
kernel_sizes = (config.residual_kernel_size, 1)
if len(kernel_sizes) != len(dilations):
raise ValueError("Number of kernel sizes should match number of dilations")
hidden = dim // config.compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [nn.ELU()]
block += [
EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
]
self.block = block
if getattr(config, "use_conv_shortcut", True):
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
else:
self.shortcut = nn.Identity()
def __call__(self, hidden_states):
residual = hidden_states
for layer in self.block:
hidden_states = layer(hidden_states)
return self.shortcut(residual) + hidden_states
class EncodecEncoder(nn.Module):
"""SEANet encoder as used by EnCodec."""
def __init__(self, config):
super().__init__()
model = [
EncodecConv1d(
config, config.audio_channels, config.num_filters, config.kernel_size
)
]
scaling = 1
for ratio in reversed(config.upsampling_ratios):
current_scale = scaling * config.num_filters
for j in range(config.num_residual_layers):
model += [
EncodecResnetBlock(
config, current_scale, [config.dilation_growth_rate**j, 1]
)
]
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
current_scale,
current_scale * 2,
kernel_size=ratio * 2,
stride=ratio,
)
]
scaling *= 2
model += [EncodecLSTM(config, scaling * config.num_filters)]
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
scaling * config.num_filters,
config.hidden_size,
config.last_kernel_size,
)
]
self.layers = model
def __call__(self, hidden_states):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
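# Shape sketch (using the 24 kHz configuration for illustration): with
# upsampling_ratios = [8, 5, 4, 2] the strided convolutions downsample by
# 2 * 4 * 5 * 8 = 320 overall, so one second of mono audio (24000, 1)
# becomes 75 latent frames of width config.hidden_size, while the filter
# count doubles at every strided conv (num_filters up to 16 * num_filters).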
class EncodecDecoder(nn.Module):
"""SEANet decoder as used by EnCodec."""
def __init__(self, config):
super().__init__()
scaling = int(2 ** len(config.upsampling_ratios))
model = [
EncodecConv1d(
config,
config.hidden_size,
scaling * config.num_filters,
config.kernel_size,
)
]
model += [EncodecLSTM(config, scaling * config.num_filters)]
for ratio in config.upsampling_ratios:
current_scale = scaling * config.num_filters
model += [nn.ELU()]
model += [
EncodecConvTranspose1d(
config,
current_scale,
current_scale // 2,
kernel_size=ratio * 2,
stride=ratio,
)
]
for j in range(config.num_residual_layers):
model += [
EncodecResnetBlock(
config, current_scale // 2, (config.dilation_growth_rate**j, 1)
)
]
scaling //= 2
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
config.num_filters,
config.audio_channels,
config.last_kernel_size,
)
]
self.layers = model
def __call__(self, hidden_states):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
class EncodecEuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance."""
def __init__(self, config):
super().__init__()
self.embed = mx.zeros((config.codebook_size, config.codebook_dim))
def quantize(self, hidden_states):
embed = self.embed.T
scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
dist = -(
scaled_states
- 2 * hidden_states @ embed
+ embed.square().sum(axis=0, keepdims=True)
)
embed_ind = dist.argmax(axis=-1)
return embed_ind
def encode(self, hidden_states):
shape = hidden_states.shape
# pre-process
hidden_states = hidden_states.reshape((-1, shape[-1]))
# quantize
embed_ind = self.quantize(hidden_states)
# post-process
embed_ind = embed_ind.reshape(*shape[:-1])
return embed_ind
def decode(self, embed_ind):
return self.embed[embed_ind]
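# Illustration (not part of the model): `quantize` relies on the expansion
# ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2, so taking the argmax of the
# negated expansion picks the nearest codebook entry:
#
#   x = mx.array([[0.9, 0.1]])                  # one vector to quantize
#   embed = mx.array([[1.0, 0.0], [0.0, 1.0]])  # two codebook entries
#   dist = -(x.square().sum(axis=1, keepdims=True)
#            - 2 * x @ embed.T
#            + embed.T.square().sum(axis=0, keepdims=True))
#   dist.argmax(axis=-1)                        # -> array([0]): [1, 0] wins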
class EncodecVectorQuantization(nn.Module):
"""
Vector quantization implementation. Currently supports only euclidean distance.
"""
def __init__(self, config):
super().__init__()
self.codebook = EncodecEuclideanCodebook(config)
def encode(self, hidden_states):
return self.codebook.encode(hidden_states)
def decode(self, embed_ind):
return self.codebook.decode(embed_ind)
class EncodecResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer."""
def __init__(self, config):
super().__init__()
self.codebook_size = config.codebook_size
hop_length = np.prod(config.upsampling_ratios)
self.frame_rate = math.ceil(config.sampling_rate / hop_length)
self.num_quantizers = int(
1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
)
self.layers = [
EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
]
def get_num_quantizers_for_bandwidth(
self, bandwidth: Optional[float] = None
) -> int:
"""Return num_quantizers based on specified target bandwidth."""
bw_per_q = math.log2(self.codebook_size) * self.frame_rate
num_quantizers = self.num_quantizers
if bandwidth is not None and bandwidth > 0.0:
num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
return num_quantizers
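    # Worked example with the 24 kHz checkpoint's published configuration
    # (illustrative numbers): upsampling_ratios = [8, 5, 4, 2] gives a hop
    # length of 320 samples, so frame_rate = ceil(24000 / 320) = 75 frames/s.
    # A 1024-entry codebook then costs log2(1024) * 75 = 750 bits/s per
    # quantizer, so a 6.0 kbps target selects floor(6000 / 750) = 8
    # quantizers out of the 32 instantiated for the 24 kbps maximum.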
def encode(
self, embeddings: mx.array, bandwidth: Optional[float] = None
) -> mx.array:
"""
Encode a given input tensor with the specified frame rate at the given
bandwidth. The RVQ encode method sets the appropriate number of
quantizers to use and returns indices for each quantizer.
"""
num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
residual = embeddings
all_indices = []
for layer in self.layers[:num_quantizers]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = mx.stack(all_indices, axis=1)
return out_indices
def decode(self, codes: mx.array) -> mx.array:
"""Decode the given codes to the quantized representation."""
quantized_out = None
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
layer = self.layers[i]
quantized = layer.decode(indices.squeeze(1))
if quantized_out is None:
quantized_out = quantized
else:
quantized_out = quantized + quantized_out
return quantized_out
class EncodecModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.encoder = EncodecEncoder(config)
self.decoder = EncodecDecoder(config)
self.quantizer = EncodecResidualVectorQuantizer(config)
self.bits_per_codebook = int(math.log2(self.config.codebook_size))
    def _encode_frame(
        self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
    ) -> Tuple[mx.array, Optional[mx.array]]:
"""
Encodes the given input using the underlying VQVAE.
If `config.normalize` is set to `True` the input is first normalized. The
padding mask is required to compute the correct scale.
"""
length = input_values.shape[1]
duration = length / self.config.sampling_rate
if (
self.config.chunk_length_s is not None
and duration > 1e-5 + self.config.chunk_length_s
):
raise RuntimeError(
f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
)
scale = None
        if self.config.normalize:
            # zero out any padded samples before computing the scale
            input_values = input_values * padding_mask
            # TODO looks like an RMSNorm
            mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
            scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
            input_values = input_values / scale
embeddings = self.encoder(input_values)
codes = self.quantizer.encode(embeddings, bandwidth)
return codes, scale
def encode(
self,
input_values: mx.array,
padding_mask: mx.array = None,
bandwidth: Optional[float] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Encodes the input audio waveform into discrete codes.
        Args:
            input_values (mx.array): The input audio waveform with shape
                ``(batch_size, sequence_length, channels)``.
            padding_mask (mx.array): Mask marking the valid (non-padded)
                samples of ``input_values``.
            bandwidth (float, optional): The target bandwidth in kbps. Must
                be one of ``config.target_bandwidths``, e.g. ``6.0`` for a
                6 kbps target. If ``None``, the smallest available bandwidth
                is used.
        Returns:
            A tuple ``(codes, scales)`` where ``codes`` contains the discrete
            codes for each chunk of the input, with shape ``(batch_size,
            num_chunks, num_codebooks, chunk_frames)``, and ``scales`` is a
            list with one rescaling factor per chunk (``None`` entries when
            ``normalize == False``).
        """
if bandwidth is None:
bandwidth = self.config.target_bandwidths[0]
if bandwidth not in self.config.target_bandwidths:
raise ValueError(
f"This model doesn't support the bandwidth {bandwidth}. "
f"Select one of {self.config.target_bandwidths}."
)
_, input_length, channels = input_values.shape
if channels < 1 or channels > 2:
raise ValueError(
f"Number of audio channels must be 1 or 2, but got {channels}"
)
chunk_length = self.chunk_length
if chunk_length is None:
chunk_length = input_length
stride = input_length
else:
stride = self.chunk_stride
if padding_mask is None:
            padding_mask = mx.ones(input_values.shape, dtype=mx.bool_)
encoded_frames = []
scales = []
step = chunk_length - stride
if (input_length % stride) - step != 0:
raise ValueError(
"The input length is not properly padded for batched chunked "
"decoding. Make sure to pad the input correctly."
)
for offset in range(0, input_length - step, stride):
mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
frame = input_values[:, offset : offset + chunk_length]
encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
encoded_frames.append(encoded_frame)
scales.append(scale)
encoded_frames = mx.stack(encoded_frames, axis=1)
return (encoded_frames, scales)
    @staticmethod
    def _linear_overlap_add(frames: List[mx.array], stride: int):
        # Generic overlap-add with linear fade-in/fade-out, supporting complex
        # scenarios, e.g. more than 2 frames per position.
        # The core idea is to use a weight function that is a triangle,
        # with a maximum value at the middle of the chunk.
        # We use this weighting when summing the frames, and divide by the sum
        # of weights for each position at the end. Thus:
        # - if a frame is the only one to cover a position, the weighting is a no-op.
        # - if 2 frames cover a position:
        #          ...  ...
        #         /   \/   \
        #        /    /\    \
        #            S  T    , i.e. S where the second frame starts, T where the first frame ends.
        #   Then the weight function for each one is (t - S) and (T - t), with `t` a given offset.
        #   After the final normalization, the weight of the second frame at position `t` is
        #   (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
        # - if more than 2 frames overlap at a given point, we hope that by induction
        #   something sensible happens.
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")
        dtype = frames[0].dtype
        batch_size, frame_length, channels = frames[0].shape
        total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
        # Triangular weight over the time axis, never zero at the endpoints.
        time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = 0.5 - mx.abs(time_vec - 0.5)
        sum_weight = mx.zeros((total_size,), dtype=dtype)
        out = mx.zeros((batch_size, total_size, channels), dtype=dtype)
        offset = 0
        for frame in frames:
            frame_length = frame.shape[1]
            out[:, offset : offset + frame_length] += (
                weight[:frame_length, None] * frame
            )
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride
        if sum_weight.min() == 0:
            raise ValueError("`sum_weight` minimum element must be bigger than zero.")
        return out / sum_weight[None, :, None]
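    # Weight sketch (illustrative): for frame_length = 4,
    # time_vec = linspace(0, 1, 6)[1:-1] = [0.2, 0.4, 0.6, 0.8], so
    # weight = [0.2, 0.4, 0.4, 0.2]. Two frames overlapping by half thus
    # cross-fade linearly, and the final division by sum_weight makes the
    # weighting a no-op wherever a single frame covers a position.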
def _decode_frame(
self, codes: mx.array, scale: Optional[mx.array] = None
) -> mx.array:
embeddings = self.quantizer.decode(codes)
outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale.reshape(-1, 1, 1)
return outputs
@property
def chunk_length(self):
if self.config.chunk_length_s is None:
return None
else:
            return int(self.config.chunk_length_s * self.config.sampling_rate)
@property
def chunk_stride(self):
if self.config.chunk_length_s is None or self.config.overlap is None:
return None
else:
return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
    def decode(
        self,
        audio_codes: mx.array,
        audio_scales: List[Optional[mx.array]],
        padding_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """
        Decodes the given frames into an output audio waveform.
        Note that the output might be a bit bigger than the input. In that
        case, any extra steps at the end can be trimmed.
        Args:
            audio_codes (mx.array): Discrete code embeddings of shape
                ``(batch_size, nb_chunks, num_codebooks, chunk_length)``.
            audio_scales (list): Scaling factor for each ``audio_codes`` chunk.
            padding_mask (mx.array): Mask marking the valid samples of the
                original input waveform.
        """
chunk_length = self.chunk_length
if chunk_length is None:
            if audio_codes.shape[1] != 1:
                raise ValueError(f"Expected one frame, got {audio_codes.shape[1]}")
audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
else:
            decoded_frames = []
            for i in range(audio_codes.shape[1]):
                frames = self._decode_frame(audio_codes[:, i], audio_scales[i])
                decoded_frames.append(frames)
audio_values = self._linear_overlap_add(
decoded_frames, self.chunk_stride or 1
)
# truncate based on padding mask
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values

encodec/hf_test.py (new file, 45 lines)

@@ -0,0 +1,45 @@
import mlx.core as mx
import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel
from utils import load, load_audio
# processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
# audio_sample = load_audio("ls_test.flac", processor.sampling_rate)
def compare_models():
pt_model = EncodecModel.from_pretrained("facebook/encodec_24khz")
mx_model = load("mlx_models")
np.random.seed(0)
audio = np.random.uniform(size=(1, 159960)).astype(np.float32)
mask = np.ones(audio.shape, dtype=np.int32)
pt_encoded = pt_model.encode(torch.tensor(audio)[None], torch.tensor(mask)[None])
mx_encoded = mx_model.encode(mx.array(audio)[..., None], mx.array(mask)[..., None])
pt_codes = pt_encoded.audio_codes.numpy()
mx_codes = mx_encoded[0]
assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"
    for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
        if mx_scale is not None:
            assert np.allclose(
                pt_scale.numpy(), mx_scale, atol=1e-3, rtol=1e-4
            ), "Encoding scales mismatch"
pt_audio = pt_model.decode(
pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
)
pt_audio = pt_audio[0].squeeze().detach().numpy()
mx_audio = mx_model.decode(*mx_encoded, mx.array(mask)[..., None])
mx_audio = mx_audio.squeeze()
assert np.allclose(
pt_audio, mx_audio, atol=1e-5, rtol=1e-5
), "Decoding audio mismatch"
# pre-process the inputs
# inputs = processor(raw_audio=np.array(audio_sample), sampling_rate=processor.sampling_rate, return_tensors="pt")
# print(inputs["input_values"].shape)
# print(inputs["padding_mask"])
if __name__ == "__main__":
compare_models()
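# To run this parity check (assumes torch and transformers are installed and
# a converted model exists in "mlx_models", produced by convert.py):
#
#   python hf_test.py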

encodec/utils.py (new file, 65 lines)

@@ -0,0 +1,65 @@
# Copyright © 2024 Apple Inc.
import json
from pathlib import Path
from subprocess import CalledProcessError, run
from types import SimpleNamespace
import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
import encodec
def load_audio(file: str, sr: int):
"""
Open an audio file and read as mono waveform, resampling as necessary
Args:
file (str): The audio file to open.
sr (int): The sample rate to resample the audio at, if needed.
Returns:
An mx.array containing the audio waveform in float32.
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0
def load(path_or_repo):
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = encodec.EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
return model
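# A minimal smoke test of the helpers above (a sketch, assuming a converted
# model in "mlx_models" produced by convert.py, a hypothetical local file
# "speech.flac", and the ffmpeg CLI on PATH):
if __name__ == "__main__":
    model = load("mlx_models")
    # load_audio returns a 1-D float32 waveform at the requested sample rate
    audio = load_audio("speech.flac", model.config.sampling_rate)
    # EncodecModel.encode expects (batch, sequence_length, channels)
    codes, scales = model.encode(audio[None, :, None])
    recon = model.decode(codes, scales)
    print("codes:", codes.shape, "reconstruction:", recon.shape)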