mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent 5c3ac52dd7
commit e319383ef9

benchmarks/python/gather_bench.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
from time import time

import mlx.core as mx
import torch


def measure_runtime(fn, **kwargs):
    # Warmup
    for _ in range(5):
        fn(**kwargs)

    tic = time()
    iters = 10
    for _ in range(iters):
        fn(**kwargs)
    return (time() - tic) * 1000 / iters


def benchmark_gather_mlx(x_shape, idx_shape):
    def gather(x, idx):
        mx.eval(x[idx])

    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)

    runtime = measure_runtime(gather, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_gather_torch(x_shape, idx_shape, device):
    def gather(x, idx, device):
        _ = x[idx]
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    idx_shapes = [(1_000_000,), (100_000,), ()]
    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]

    for x_shape, idx_shape in zip(x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_gather_mlx(x_shape, idx_shape)
        benchmark_gather_torch(x_shape, idx_shape, device=device)
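Usage note: the benchmark is driven by the __main__ block above; running python benchmarks/python/gather_bench.py compares MLX against PyTorch on the MPS device, while --cpu pins MLX to mx.cpu and PyTorch to the CPU. measure_runtime performs five warmup calls and then reports the mean of ten timed calls in milliseconds.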
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@@ -39,9 +39,15 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  size_t ndim = src.ndim();

  std::ostringstream kname;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
  if (idx_ndim <= 1) {
    kname << "_" << idx_ndim;
  }

  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
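For orientation, the kernel name built here is what selects among the specializations instantiated in the Metal file further down: the "_0"/"_1" suffix is only appended for the compile-time 0-d and 1-d index cases. A rough Python illustration of the naming scheme follows (the helper name and the type strings such as "float32"/"uint32" are illustrative assumptions, not part of this change):

def gather_kernel_name(src_type, idx_type, nidx, idx_ndim):
    # Mirrors the kname construction in Gather::eval_gpu: base name plus the
    # number of index arrays, with a "_<idx_ndim>" suffix only when idx_ndim <= 1.
    name = f"gather{src_type}{idx_type if nidx else ''}_{nidx}"
    if idx_ndim <= 1:
        name += f"_{idx_ndim}"
    return name

print(gather_kernel_name("float32", "uint32", 1, 1))  # gatherfloat32uint32_1_1
print(gather_kernel_name("float32", "uint32", 1, 3))  # gatherfloat32uint32_1 (general case)
print(gather_kernel_name("bool_", "bool_", 0, 0))     # gatherbool__0_0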
@@ -51,15 +57,11 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
    slice_size *= s;
  }

  size_t ndim = src.ndim();
  size_t nthreads = out.size();
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }

  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  // Launch 2D grid of threads: indices x slice
  size_t dim0 = out.size() / slice_size;
  size_t dim1 = slice_size;
  auto group_dims = get_block_dims(dim0, dim1, 1);
  MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);

  compute_encoder->setComputePipelineState(kernel);
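The launch change above replaces the old 1-D dispatch (one thread per output element, addressed by a flat gid) with a 2-D grid: the x dimension walks the gathered indices (out.size() / slice_size rows) and the y dimension walks the slice. A small sketch of the index arithmetic, checking that the new (index.x, index.y) scheme covers the same output element the old flat gid did (function names are illustrative only):

def out_pos_1d(gid, slice_size):
    # Old scheme: split a flat thread id into (index row, offset within the slice).
    return gid // slice_size, gid % slice_size, gid

def out_pos_2d(ix, iy, grid_dim_y):
    # New scheme: thread (ix, iy) handles index row ix and slice offset iy;
    # the flat output position is iy + grid_dim.y * ix, as in the updated kernel.
    return ix, iy, iy + grid_dim_y * ix

slice_size = 4                      # plays the role of dim1 == grid_dim.y
for gid in range(3 * slice_size):   # 3 index rows, 4 elements per slice
    ix, iy = gid // slice_size, gid % slice_size
    assert out_pos_1d(gid, slice_size) == out_pos_2d(ix, iy, slice_size)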
@@ -90,7 +92,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto arg_enc = d.argument_encoder(arg_descs);

  // Allocate and fill buffers for shapes and strides
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
  auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
@@ -130,12 +131,12 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  set_array_buffer(compute_encoder, src, 0);
  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
  set_array_buffer(compute_encoder, out, 2);

  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
  compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
  compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7);

  compute_encoder->dispatchThreads(grid_dims, group_dims);
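Note that the argument table shrinks by one entry here: &slice_size is no longer passed to the kernel (the slice extent is now carried by the grid's y dimension), so axes_ moves from buffer(8) to buffer(7), matching the updated kernel signature below.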
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>
#include <metal_texture>
@@ -36,29 +36,38 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
  return idx;
}

template <typename T, typename IdxT, int NIDX>
// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time
// special case for 0 and 1. Anything >= 2 uses the general case
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
[[kernel]] void gather(
    const device T *src [[buffer(0)]],
    const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
    const constant Indices<IdxT, NIDX>& indices [[buffer(1)]],
    device T *out [[buffer(2)]],
    const device int *src_shape [[buffer(3)]],
    const device size_t *src_strides [[buffer(4)]],
    const device size_t& src_ndim [[buffer(5)]],
    const device int *slice_sizes [[buffer(6)]],
    const device size_t& slice_size [[buffer(7)]],
    const device int *axes [[buffer(8)]],
    uint gid [[thread_position_in_grid]]) {
    const constant int *src_shape [[buffer(3)]],
    const constant size_t *src_strides [[buffer(4)]],
    const constant size_t& src_ndim [[buffer(5)]],
    const constant int *slice_sizes [[buffer(6)]],
    const constant int *axes [[buffer(7)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {

  auto ind_idx = gid / slice_size;
  auto ind_offset = gid % slice_size;
  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], src_shape[ax]);
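The IDX_NDIM branch above is resolved at compile time: the 0-d specialization reads a single scalar index, the 1-d specialization turns the row index into an offset with one stride multiply, and only the general case (indices with two or more dimensions) pays for the full elem_to_loc shape/stride walk, which is presumably the point of adding these specializations alongside the gather benchmark.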
@@ -67,22 +76,49 @@ template <typename T, typename IdxT, int NIDX>

  auto src_offset = elem_to_loc(
      ind_offset, slice_sizes, src_strides, src_ndim);
  out[gid] = src[src_idx + src_offset];

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}

#define instantiate_gather4(name, src_type, ind_type, nindex) \
  template [[host_name("gather" name "_" #nindex)]] \
  [[kernel]] void gather( \
  template [[host_name("gather" name "_" #nindex "_0")]] \
  [[kernel]] void gather<src_type, ind_type, nindex, 0>( \
      const device src_type *src [[buffer(0)]], \
      const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
      const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
      device src_type *out [[buffer(2)]], \
      const device int *src_shape [[buffer(3)]], \
      const device size_t *src_strides [[buffer(4)]], \
      const device size_t& src_ndim [[buffer(5)]], \
      const device int *slice_sizes [[buffer(6)]], \
      const device size_t& slice_size [[buffer(7)]], \
      const device int* axes [[buffer(8)]], \
      uint gid [[thread_position_in_grid]]);
      const constant int *src_shape [[buffer(3)]], \
      const constant size_t *src_strides [[buffer(4)]], \
      const constant size_t& src_ndim [[buffer(5)]], \
      const constant int *slice_sizes [[buffer(6)]], \
      const constant int* axes [[buffer(7)]], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name("gather" name "_" #nindex "_1")]] \
  [[kernel]] void gather<src_type, ind_type, nindex, 1>( \
      const device src_type *src [[buffer(0)]], \
      const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
      device src_type *out [[buffer(2)]], \
      const constant int *src_shape [[buffer(3)]], \
      const constant size_t *src_strides [[buffer(4)]], \
      const constant size_t& src_ndim [[buffer(5)]], \
      const constant int *slice_sizes [[buffer(6)]], \
      const constant int* axes [[buffer(7)]], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name("gather" name "_" #nindex)]] \
  [[kernel]] void gather<src_type, ind_type, nindex, 2>( \
      const device src_type *src [[buffer(0)]], \
      const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
      device src_type *out [[buffer(2)]], \
      const constant int *src_shape [[buffer(3)]], \
      const constant size_t *src_strides [[buffer(4)]], \
      const constant size_t& src_ndim [[buffer(5)]], \
      const constant int *slice_sizes [[buffer(6)]], \
      const constant int* axes [[buffer(7)]], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]);


// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
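Expanded, each instantiate_gather4(name, src_type, ind_type, nindex) now emits three host-visible kernels instead of one: "gather" name "_" #nindex "_0" and "..._1" for the 0-d and 1-d index specializations, plus the unsuffixed "gather" name "_" #nindex for the general IDX_NDIM >= 2 path. These are exactly the names Gather::eval_gpu constructs above; for the NIDX=0 instantiation shown here that yields gatherbool__0_0, gatherbool__0_1, and gatherbool__0.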