mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 19:51:13 +08:00
Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since there kernels are integer bound. Increase number of iterations for benchmarks for better smoothing. Github Issue #506 Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
This commit is contained in:
parent
be6e9d6a9f
commit
2fdc2462c3
@ -28,7 +28,7 @@ def measure_runtime(fn, **kwargs):
|
||||
fn(**kwargs)
|
||||
|
||||
tic = time.time()
|
||||
iters = 10
|
||||
iters = 100
|
||||
for _ in range(iters):
|
||||
fn(**kwargs)
|
||||
return (time.time() - tic) * 1000 / iters
|
||||
|
@ -71,7 +71,7 @@ inline size_t elem_to_loc(
|
||||
device const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
@ -84,7 +84,7 @@ inline size_t elem_to_loc(
|
||||
constant const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user