Faster gather and scatter. (#682)

Reduce unnecessary integer ops, especially since
these kernels are integer-bound.

Increase number of iterations for benchmarks for
better smoothing.

GitHub Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
This commit is contained in:
Vijay Krish 2024-02-13 17:47:41 -08:00 committed by GitHub
parent be6e9d6a9f
commit 2fdc2462c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@ -28,7 +28,7 @@ def measure_runtime(fn, **kwargs):
fn(**kwargs) fn(**kwargs)
tic = time.time() tic = time.time()
iters = 10 iters = 100
for _ in range(iters): for _ in range(iters):
fn(**kwargs) fn(**kwargs)
return (time.time() - tic) * 1000 / iters return (time.time() - tic) * 1000 / iters

View File

@ -71,7 +71,7 @@ inline size_t elem_to_loc(
device const size_t* strides, device const size_t* strides,
int ndim) { int ndim) {
size_t loc = 0; size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) { for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i]; loc += (elem % shape[i]) * strides[i];
elem /= shape[i]; elem /= shape[i];
} }
@ -84,7 +84,7 @@ inline size_t elem_to_loc(
constant const size_t* strides, constant const size_t* strides,
int ndim) { int ndim) {
size_t loc = 0; size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) { for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i]; loc += (elem % shape[i]) * strides[i];
elem /= shape[i]; elem /= shape[i];
} }