Faster gather (#626)

* faster gather

* update copyright
This commit is contained in:
Awni Hannun 2024-02-04 17:25:44 -08:00 committed by GitHub
parent 5c3ac52dd7
commit e319383ef9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 142 additions and 41 deletions

View File

@ -0,0 +1,64 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time()
iters = 10
for _ in range(iters):
fn(**kwargs)
return (time() - tic) * 1000 / iters
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@ -39,9 +39,15 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
if (idx_ndim <= 1) {
kname << "_" << idx_ndim;
}
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
@ -51,15 +57,11 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
slice_size *= s;
}
size_t ndim = src.ndim();
size_t nthreads = out.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
// Launch 2D grid of threads: indices x slice
size_t dim0 = out.size() / slice_size;
size_t dim1 = slice_size;
auto group_dims = get_block_dims(dim0, dim1, 1);
MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
compute_encoder->setComputePipelineState(kernel);
@ -90,7 +92,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
for (int i = 0; i < nidx; ++i) {
@ -130,12 +131,12 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
set_array_buffer(compute_encoder, src, 0);
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include <metal_texture>
@ -36,29 +36,38 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT, int NIDX>
// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time
// special case for 0 and 1. Anything >= 2 uses the general case
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
[[kernel]] void gather(
const device T *src [[buffer(0)]],
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
const constant Indices<IdxT, NIDX>& indices [[buffer(1)]],
device T *out [[buffer(2)]],
const device int *src_shape [[buffer(3)]],
const device size_t *src_strides [[buffer(4)]],
const device size_t& src_ndim [[buffer(5)]],
const device int *slice_sizes [[buffer(6)]],
const device size_t& slice_size [[buffer(7)]],
const device int *axes [[buffer(8)]],
uint gid [[thread_position_in_grid]]) {
const constant int *src_shape [[buffer(3)]],
const constant size_t *src_strides [[buffer(4)]],
const constant size_t& src_ndim [[buffer(5)]],
const constant int *slice_sizes [[buffer(6)]],
const constant int *axes [[buffer(7)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = gid / slice_size;
auto ind_offset = gid % slice_size;
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
@ -67,22 +76,49 @@ template <typename T, typename IdxT, int NIDX>
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
out[gid] = src[src_idx + src_offset];
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather( \
template [[host_name("gather" name "_" #nindex "_0")]] \
[[kernel]] void gather<src_type, ind_type, nindex, 0>( \
const device src_type *src [[buffer(0)]], \
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const device int *src_shape [[buffer(3)]], \
const device size_t *src_strides [[buffer(4)]], \
const device size_t& src_ndim [[buffer(5)]], \
const device int *slice_sizes [[buffer(6)]], \
const device size_t& slice_size [[buffer(7)]], \
const device int* axes [[buffer(8)]], \
uint gid [[thread_position_in_grid]]);
const constant int *src_shape [[buffer(3)]], \
const constant size_t *src_strides [[buffer(4)]], \
const constant size_t& src_ndim [[buffer(5)]], \
const constant int *slice_sizes [[buffer(6)]], \
const constant int* axes [[buffer(7)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("gather" name "_" #nindex "_1")]] \
[[kernel]] void gather<src_type, ind_type, nindex, 1>( \
const device src_type *src [[buffer(0)]], \
const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const constant int *src_shape [[buffer(3)]], \
const constant size_t *src_strides [[buffer(4)]], \
const constant size_t& src_ndim [[buffer(5)]], \
const constant int *slice_sizes [[buffer(6)]], \
const constant int* axes [[buffer(7)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather<src_type, ind_type, nindex, 2>( \
const device src_type *src [[buffer(0)]], \
const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const constant int *src_shape [[buffer(3)]], \
const constant size_t *src_strides [[buffer(4)]], \
const constant size_t& src_ndim [[buffer(5)]], \
const constant int *slice_sizes [[buffer(6)]], \
const constant int* axes [[buffer(7)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)