diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py index 32d22ed99..f10635ec9 100644 --- a/benchmarks/python/time_utils.py +++ b/benchmarks/python/time_utils.py @@ -28,7 +28,7 @@ def measure_runtime(fn, **kwargs): fn(**kwargs) tic = time.time() - iters = 10 + iters = 100 for _ in range(iters): fn(**kwargs) return (time.time() - tic) * 1000 / iters diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index f9d507cf2..8ef1127b6 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -71,7 +71,7 @@ inline size_t elem_to_loc( device const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; } @@ -84,7 +84,7 @@ inline size_t elem_to_loc( constant const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; }