mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 11:41:14 +08:00
Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since these kernels are integer-bound. Increase the number of iterations for benchmarks for better smoothing. GitHub Issue #506. Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
This commit is contained in:
parent
be6e9d6a9f
commit
2fdc2462c3
@ -28,7 +28,7 @@ def measure_runtime(fn, **kwargs):
|
|||||||
fn(**kwargs)
|
fn(**kwargs)
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
iters = 10
|
iters = 100
|
||||||
for _ in range(iters):
|
for _ in range(iters):
|
||||||
fn(**kwargs)
|
fn(**kwargs)
|
||||||
return (time.time() - tic) * 1000 / iters
|
return (time.time() - tic) * 1000 / iters
|
||||||
|
@ -71,7 +71,7 @@ inline size_t elem_to_loc(
|
|||||||
device const size_t* strides,
|
device const size_t* strides,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
size_t loc = 0;
|
size_t loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||||
loc += (elem % shape[i]) * strides[i];
|
loc += (elem % shape[i]) * strides[i];
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
@ -84,7 +84,7 @@ inline size_t elem_to_loc(
|
|||||||
constant const size_t* strides,
|
constant const size_t* strides,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
size_t loc = 0;
|
size_t loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||||
loc += (elem % shape[i]) * strides[i];
|
loc += (elem % shape[i]) * strides[i];
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user