diff --git a/benchmarks/python/batch_matmul_bench.py b/benchmarks/python/batch_matmul_bench.py index 84bf95fb0..bf0f74b82 100644 --- a/benchmarks/python/batch_matmul_bench.py +++ b/benchmarks/python/batch_matmul_bench.py @@ -30,7 +30,7 @@ def time_batch_matmul(): time_fn(batch_vjp_second) -def time_unbatch_matmul(key): +def time_unbatch_matmul(): mx.random.seed(3) a = mx.random.uniform(shape=(B * T, D)) b = mx.random.uniform(shape=(D, D))