From 2fdc2462c3b7e261d69700075a50252cb2e8c83f Mon Sep 17 00:00:00 2001 From: Vijay Krish Date: Tue, 13 Feb 2024 17:47:41 -0800 Subject: [PATCH] Faster gather and scatter. (#682) Reduce unnecessary integer ops, especially since there kernels are integer bound. Increase number of iterations for benchmarks for better smoothing. Github Issue #506 Co-authored-by: Vijay Krishnamoorthy --- benchmarks/python/time_utils.py | 2 +- mlx/backend/metal/kernels/utils.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py index 32d22ed99..f10635ec9 100644 --- a/benchmarks/python/time_utils.py +++ b/benchmarks/python/time_utils.py @@ -28,7 +28,7 @@ def measure_runtime(fn, **kwargs): fn(**kwargs) tic = time.time() - iters = 10 + iters = 100 for _ in range(iters): fn(**kwargs) return (time.time() - tic) * 1000 / iters diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index f9d507cf2..8ef1127b6 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -71,7 +71,7 @@ inline size_t elem_to_loc( device const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; } @@ -84,7 +84,7 @@ inline size_t elem_to_loc( constant const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; }