mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 19:51:13 +08:00
Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code, since 64b integer division is very expensive. Github Issue #506 Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
This commit is contained in:
parent
11d2c8f7a1
commit
06072601ce
@ -5,18 +5,7 @@ from time import time
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
|
||||
|
||||
def measure_runtime(fn, **kwargs):
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
fn(**kwargs)
|
||||
|
||||
tic = time()
|
||||
iters = 10
|
||||
for _ in range(iters):
|
||||
fn(**kwargs)
|
||||
return (time() - tic) * 1000 / iters
|
||||
from time_utils import measure_runtime
|
||||
|
||||
|
||||
def benchmark_gather_mlx(x_shape, idx_shape):
|
||||
|
56
benchmarks/python/scatter_bench.py
Normal file
56
benchmarks/python/scatter_bench.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
|
||||
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
|
||||
def scatter(dst, x, idx):
|
||||
dst[idx] = x
|
||||
mx.eval(dst)
|
||||
|
||||
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
|
||||
x = mx.random.normal(x_shape).astype(mx.float32)
|
||||
dst = mx.random.normal(dst_shape).astype(mx.float32)
|
||||
|
||||
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
|
||||
print(f"MLX: {runtime:.3f}ms")
|
||||
|
||||
|
||||
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
|
||||
def gather(dst, x, idx, device):
|
||||
dst[idx] = x
|
||||
if device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
|
||||
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
|
||||
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
||||
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
|
||||
|
||||
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
|
||||
print(f"PyTorch: {runtime:.3f}ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Gather benchmarks.")
|
||||
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device("mps")
|
||||
|
||||
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
|
||||
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
|
||||
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
|
||||
|
||||
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
|
||||
print("=" * 20)
|
||||
print(f"X {x_shape}, Indices {idx_shape}")
|
||||
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
|
||||
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
|
@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
|
||||
|
||||
msec = 1e3 * (toc - tic) / num_iters
|
||||
print(f"{msec:.5f} msec")
|
||||
|
||||
|
||||
def measure_runtime(fn, **kwargs):
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
fn(**kwargs)
|
||||
|
||||
tic = time.time()
|
||||
iters = 10
|
||||
for _ in range(iters):
|
||||
fn(**kwargs)
|
||||
return (time.time() - tic) * 1000 / iters
|
||||
|
@ -212,9 +212,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Make the argument buffer to store the indices for the
|
||||
@ -317,6 +314,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Cleanup temporaries
|
||||
|
@ -187,11 +187,11 @@ template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
const device size_t *out_strides [[buffer(8)]],
|
||||
const device size_t& out_ndim [[buffer(9)]],
|
||||
const device int* axes [[buffer(10)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
auto ind_idx = gid / upd_size;
|
||||
auto ind_offset = gid % upd_size;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
@ -208,7 +208,7 @@ template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
|
||||
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
|
||||
}
|
||||
|
||||
@ -226,7 +226,7 @@ template [[host_name("scatter" name "_" #nindex)]] \
|
||||
const device size_t *out_strides [[buffer(8)]], \
|
||||
const device size_t& out_ndim [[buffer(9)]], \
|
||||
const device int* axes [[buffer(10)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
// Special case NINDEX=0
|
||||
#define instantiate_scatter_nd0(name, type) \
|
||||
|
Loading…
Reference in New Issue
Block a user