mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 22:58:08 +08:00
Encodec (#991)
* initial encodec * works * nits * use fast group norm * fix for rnn layer * fix mlx version * use custom LSTM kernel * audio encodec * fix example, support batched inference * nits
This commit is contained in:
30
encodec/benchmarks/bench_mx.py
Normal file
30
encodec/benchmarks/bench_mx.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load
|
||||
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
|
||||
audio = mx.random.uniform(shape=(288000, 2))
|
||||
feats, mask = processor(audio)
|
||||
mx.eval(model, feats, mask)
|
||||
|
||||
|
||||
@mx.compile
|
||||
def fun():
|
||||
codes, scales = model.encode(feats, mask, bandwidth=3)
|
||||
reconstructed = model.decode(codes, scales, mask)
|
||||
return reconstructed
|
||||
|
||||
|
||||
for _ in range(5):
|
||||
mx.eval(fun())
|
||||
|
||||
tic = time.time()
|
||||
for _ in range(10):
|
||||
mx.eval(fun())
|
||||
toc = time.time()
|
||||
ms = 1000 * (toc - tic) / 10
|
||||
print(f"Time per it: {ms:.3f}")
|
||||
Reference in New Issue
Block a user