Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-15 21:21:16 +08:00)
Scatter optimization: Eliminate 64b integer divide. (#662)
Launch a 2D grid to eliminate the divide and mod in device code, since 64-bit integer division is very expensive on the GPU.

GitHub Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
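As a minimal illustration (not part of the commit), the sketch below mimics the index math the kernel change relies on: the old kernel recovered the index row and the within-row offset from a flat 1D thread id with a 64-bit divide and mod, while the new 2D launch of shape (upd_size, nthreads / upd_size) reads the same two values directly from gid.x and gid.y. The sizes used here are illustrative, not taken from the commit.

# Sketch only: check that a 2D launch reproduces the old divide/mod mapping.
upd_size, n_rows = 64, 1000

for gid_y in range(n_rows):            # new kernel: ind_idx = gid.y
    for gid_x in range(upd_size):      # new kernel: ind_offset = gid.x
        flat = gid_y * upd_size + gid_x        # old 1D thread id
        assert flat // upd_size == gid_y       # old: ind_idx = gid / upd_size
        assert flat % upd_size == gid_x        # old: ind_offset = gid % upd_size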
This commit is contained in:
parent 11d2c8f7a1
commit 06072601ce
@@ -5,18 +5,7 @@ from time import time
 import mlx.core as mx
 import torch
-
-
-def measure_runtime(fn, **kwargs):
-    # Warmup
-    for _ in range(5):
-        fn(**kwargs)
-
-    tic = time()
-    iters = 10
-    for _ in range(iters):
-        fn(**kwargs)
-    return (time() - tic) * 1000 / iters
+from time_utils import measure_runtime
 
 
 def benchmark_gather_mlx(x_shape, idx_shape):
benchmarks/python/scatter_bench.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import argparse
+
+import mlx.core as mx
+import torch
+from time_utils import measure_runtime
+
+
+def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
+    def scatter(dst, x, idx):
+        dst[idx] = x
+        mx.eval(dst)
+
+    idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
+    x = mx.random.normal(x_shape).astype(mx.float32)
+    dst = mx.random.normal(dst_shape).astype(mx.float32)
+
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
+    print(f"MLX: {runtime:.3f}ms")
+
+
+def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
+    def gather(dst, x, idx, device):
+        dst[idx] = x
+        if device == torch.device("mps"):
+            torch.mps.synchronize()
+
+    idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
+    x = torch.randn(x_shape, dtype=torch.float32).to(device)
+    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
+
+    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    print(f"PyTorch: {runtime:.3f}ms")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Gather benchmarks.")
+    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
+    args = parser.parse_args()
+
+    if args.cpu:
+        mx.set_default_device(mx.cpu)
+        device = torch.device("cpu")
+    else:
+        device = torch.device("mps")
+
+    dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
+    idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
+    x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
+
+    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
+        print("=" * 20)
+        print(f"X {x_shape}, Indices {idx_shape}")
+        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
+        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
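Usage note (not part of the diff): the new benchmark is a standalone script; running python benchmarks/python/scatter_bench.py compares MLX against PyTorch on the MPS device, and passing --cpu runs both frameworks on the CPU instead.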
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import time
 
@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
 
     msec = 1e3 * (toc - tic) / num_iters
     print(f"{msec:.5f} msec")
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time.time()
+    iters = 10
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time.time() - tic) * 1000 / iters
@@ -212,9 +212,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     thread_group_size = nthreads;
   }
 
-  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-
   compute_encoder->setComputePipelineState(kernel);
 
   // Make the argument buffer to store the indices for the
@@ -317,6 +314,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
   compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
 
+  MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+  MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 
   // Cleanup temporaries
@@ -187,11 +187,11 @@ template <typename T, typename IdxT, typename Op, int NIDX>
     const device size_t *out_strides [[buffer(8)]],
     const device size_t& out_ndim [[buffer(9)]],
     const device int* axes [[buffer(10)]],
-    uint gid [[thread_position_in_grid]]) {
+    uint2 gid [[thread_position_in_grid]]) {
 
   Op op;
-  auto ind_idx = gid / upd_size;
-  auto ind_offset = gid % upd_size;
+  auto ind_idx = gid.y;
+  auto ind_offset = gid.x;
 
   size_t out_idx = 0;
   for (int i = 0; i < NIDX; ++i) {
@@ -208,7 +208,7 @@ template <typename T, typename IdxT, typename Op, int NIDX>
 
   auto out_offset = elem_to_loc(
       ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
-  auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
+  auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
   op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
 }
 
@@ -226,7 +226,7 @@ template [[host_name("scatter" name "_" #nindex)]] \
       const device size_t *out_strides [[buffer(8)]], \
       const device size_t& out_ndim [[buffer(9)]], \
       const device int* axes [[buffer(10)]], \
-      uint gid [[thread_position_in_grid]]);
+      uint2 gid [[thread_position_in_grid]]);
 
 // Special case NINDEX=0
 #define instantiate_scatter_nd0(name, type) \