mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
Scatter optimization: Eliminate 64-bit integer divide. (#662)
Launch a 2D grid to eliminate the divide and modulo operations in device code, since 64-bit integer division is very expensive. GitHub issue #506. Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
This commit is contained in:
@@ -5,18 +5,7 @@ from time import time
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
|
||||
|
||||
def measure_runtime(fn, **kwargs):
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
fn(**kwargs)
|
||||
|
||||
tic = time()
|
||||
iters = 10
|
||||
for _ in range(iters):
|
||||
fn(**kwargs)
|
||||
return (time() - tic) * 1000 / iters
|
||||
from time_utils import measure_runtime
|
||||
|
||||
|
||||
def benchmark_gather_mlx(x_shape, idx_shape):
|
||||
|
||||
Reference in New Issue
Block a user