Scatter optimization: Eliminate 64b integer divide. (#662)

Launch a 2D grid to eliminate the divide and modulo in device code,
since 64-bit integer division is very expensive.
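
As a rough illustration of the idea above (not the actual device code from this change), the Python sketch below contrasts a flat 1D launch, where each thread recovers its two coordinates with a divide and a modulo, against a 2D launch, where the coordinates come directly from the grid position. The function names and the `rows`/`cols` parameters are hypothetical and chosen only for this example.

```python
def coords_from_1d_grid(rows, cols):
    """Simulate a flat 1D launch: every thread id needs a divide and a modulo."""
    coords = []
    for gid in range(rows * cols):
        row = gid // cols  # integer divide (the costly step for 64-bit ids)
        col = gid % cols   # integer modulo
        coords.append((row, col))
    return coords


def coords_from_2d_grid(rows, cols):
    """Simulate a 2D launch: the grid position already is the coordinate pair."""
    coords = []
    for gid_y in range(rows):
        for gid_x in range(cols):
            coords.append((gid_y, gid_x))  # no divide or modulo needed
    return coords


if __name__ == "__main__":
    # Both launches visit the same (row, col) pairs in the same order.
    print(coords_from_1d_grid(2, 3) == coords_from_2d_grid(2, 3))
```

Both functions enumerate the same pairs. The only difference is how each simulated thread obtains its coordinates, which is the cost the 2D launch avoids.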

GitHub issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
Author: Vijay Krish (committed by GitHub)
Date: 2024-02-10 08:49:51 -08:00
Parent: 11d2c8f7a1
Commit: 06072601ce
5 changed files with 77 additions and 21 deletions

@@ -5,18 +5,7 @@ from time import time
 import mlx.core as mx
 import torch
-
-
-def measure_runtime(fn, **kwargs):
-    # Warmup
-    for _ in range(5):
-        fn(**kwargs)
-
-    tic = time()
-    iters = 10
-    for _ in range(iters):
-        fn(**kwargs)
-    return (time() - tic) * 1000 / iters
+from time_utils import measure_runtime
 
 
 def benchmark_gather_mlx(x_shape, idx_shape):