Scatter optimization : Eliminate 64b integer divide. (#662)

Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
This commit is contained in:
Vijay Krish 2024-02-10 08:49:51 -08:00 committed by GitHub
parent 11d2c8f7a1
commit 06072601ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 77 additions and 21 deletions

View File

@ -5,18 +5,7 @@ from time import time
import mlx.core as mx
import torch
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time()
iters = 10
for _ in range(iters):
fn(**kwargs)
return (time() - tic) * 1000 / iters
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):

View File

@ -0,0 +1,56 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
def scatter(dst, x, idx):
dst[idx] = x
mx.eval(dst)
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
def gather(dst, x, idx, device):
dst[idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import time
@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time.time()
iters = 10
for _ in range(iters):
fn(**kwargs)
return (time.time() - tic) * 1000 / iters

View File

@ -212,9 +212,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
@ -317,6 +314,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries

View File

@ -187,11 +187,11 @@ template <typename T, typename IdxT, typename Op, int NIDX>
const device size_t *out_strides [[buffer(8)]],
const device size_t& out_ndim [[buffer(9)]],
const device int* axes [[buffer(10)]],
uint gid [[thread_position_in_grid]]) {
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid / upd_size;
auto ind_offset = gid % upd_size;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
@ -208,7 +208,7 @@ template <typename T, typename IdxT, typename Op, int NIDX>
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
@ -226,7 +226,7 @@ template [[host_name("scatter" name "_" #nindex)]] \
const device size_t *out_strides [[buffer(8)]], \
const device size_t& out_ndim [[buffer(9)]], \
const device int* axes [[buffer(10)]], \
uint gid [[thread_position_in_grid]]);
uint2 gid [[thread_position_in_grid]]);
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \