mirror of https://github.com/ml-explore/mlx.git
parent 5c3ac52dd7
commit e319383ef9

benchmarks/python/gather_bench.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import argparse
+from time import time
+
+import mlx.core as mx
+import torch
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time()
+    iters = 10
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time() - tic) * 1000 / iters
+
+
+def benchmark_gather_mlx(x_shape, idx_shape):
+    def gather(x, idx):
+        mx.eval(x[idx])
+
+    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
+    x = mx.random.normal(x_shape).astype(mx.float32)
+
+    runtime = measure_runtime(gather, x=x, idx=idx)
+    print(f"MLX: {runtime:.3f}ms")
+
+
+def benchmark_gather_torch(x_shape, idx_shape, device):
+    def gather(x, idx, device):
+        _ = x[idx]
+        if device == torch.device("mps"):
+            torch.mps.synchronize()
+
+    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
+    x = torch.randn(x_shape, dtype=torch.float32).to(device)
+
+    runtime = measure_runtime(gather, x=x, idx=idx, device=device)
+    print(f"PyTorch: {runtime:.3f}ms")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Gather benchmarks.")
+    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
+    args = parser.parse_args()
+
+    if args.cpu:
+        mx.set_default_device(mx.cpu)
+        device = torch.device("cpu")
+    else:
+        device = torch.device("mps")
+
+    idx_shapes = [(1_000_000,), (100_000,), ()]
+    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
+
+    for x_shape, idx_shape in zip(x_shapes, idx_shapes):
+        print("=" * 20)
+        print(f"X {x_shape}, Indices {idx_shape}")
+        benchmark_gather_mlx(x_shape, idx_shape)
+        benchmark_gather_torch(x_shape, idx_shape, device=device)
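A note on the timing pattern above: MLX builds computations lazily and Metal work is submitted asynchronously, so the MLX path times mx.eval(x[idx]) and the PyTorch path calls torch.mps.synchronize() on MPS; otherwise the loop would mostly measure graph construction and kernel submission rather than the gather itself. The standalone sketch below illustrates the difference; it is not part of the committed file, and the shapes are arbitrary example values.

    # Standalone sketch: why the benchmark forces evaluation inside the timed
    # function. Shapes are illustrative only.
    from time import time

    import mlx.core as mx

    x = mx.random.normal((100, 1024)).astype(mx.float32)
    idx = mx.random.randint(0, 99, (100_000,))

    tic = time()
    y = x[idx]                      # lazy: only records the gather
    build_ms = (time() - tic) * 1000

    tic = time()
    mx.eval(x[idx])                 # forces the gather to actually run
    run_ms = (time() - tic) * 1000

    print(f"graph build only: {build_ms:.3f}ms, with mx.eval: {run_ms:.3f}ms")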
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 #include <algorithm>
 #include <cassert>
 #include <numeric>
@@ -39,9 +39,15 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
+  int idx_ndim = nidx ? inputs[1].ndim() : 0;
+  size_t ndim = src.ndim();
+
   std::ostringstream kname;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
   kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
+  if (idx_ndim <= 1) {
+    kname << "_" << idx_ndim;
+  }
 
   auto compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
@@ -51,15 +57,11 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     slice_size *= s;
   }
 
-  size_t ndim = src.ndim();
-  size_t nthreads = out.size();
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  if (thread_group_size > nthreads) {
-    thread_group_size = nthreads;
-  }
-
-  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+  // Launch 2D grid of threads: indices x slice
+  size_t dim0 = out.size() / slice_size;
+  size_t dim1 = slice_size;
+  auto group_dims = get_block_dims(dim0, dim1, 1);
+  MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
 
   compute_encoder->setComputePipelineState(kernel);
 
@@ -90,7 +92,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto arg_enc = d.argument_encoder(arg_descs);
 
   // Allocate and fill buffers for shapes and strides
-  int idx_ndim = nidx ? inputs[1].ndim() : 0;
   auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
   auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
   for (int i = 0; i < nidx; ++i) {
@@ -130,12 +131,12 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   set_array_buffer(compute_encoder, src, 0);
   compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
   set_array_buffer(compute_encoder, out, 2);
 
   compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
   compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
   compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
   compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
-  compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
-  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
+  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7);
 
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 
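The host-side hunks above replace the old 1-D dispatch of out.size() threads with a 2-D grid: dim0 is the number of gathered indices (out.size() / slice_size) and dim1 is the slice size, with threadgroup dimensions picked by get_block_dims. The Python sketch below is only an illustration of the index arithmetic, not MLX code: it checks that the new kernel's flat output index, index.y + grid_dim.y * index.x, addresses the same element the old kernel reached through gid / slice_size and gid % slice_size.

    # Illustrative check of the 2D index mapping, assuming
    # grid_dim == (out_size // slice_size, slice_size) as set up in eval_gpu.
    def check_mapping(out_size: int, slice_size: int) -> None:
        dim1 = slice_size
        for gid in range(out_size):
            # Old 1D kernel: derive coordinates from the flat thread id.
            ind_idx, ind_offset = divmod(gid, slice_size)
            # New 2D kernel: (index.x, index.y) == (ind_idx, ind_offset),
            # recombined as out_idx = index.y + grid_dim.y * index.x.
            out_idx = ind_offset + dim1 * ind_idx
            assert out_idx == gid

    check_mapping(out_size=12, slice_size=4)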
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <metal_atomic>
 #include <metal_texture>
@@ -36,29 +36,38 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
   return idx;
 }
 
-template <typename T, typename IdxT, int NIDX>
+// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time
+// special case for 0 and 1. Anything >= 2 uses the general case
+template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
 [[kernel]] void gather(
     const device T *src [[buffer(0)]],
-    const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
+    const constant Indices<IdxT, NIDX>& indices [[buffer(1)]],
     device T *out [[buffer(2)]],
-    const device int *src_shape [[buffer(3)]],
-    const device size_t *src_strides [[buffer(4)]],
-    const device size_t& src_ndim [[buffer(5)]],
-    const device int *slice_sizes [[buffer(6)]],
-    const device size_t& slice_size [[buffer(7)]],
-    const device int *axes [[buffer(8)]],
-    uint gid [[thread_position_in_grid]]) {
+    const constant int *src_shape [[buffer(3)]],
+    const constant size_t *src_strides [[buffer(4)]],
+    const constant size_t& src_ndim [[buffer(5)]],
+    const constant int *slice_sizes [[buffer(6)]],
+    const constant int *axes [[buffer(7)]],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
 
-  auto ind_idx = gid / slice_size;
-  auto ind_offset = gid % slice_size;
+  auto ind_idx = index.x;
+  auto ind_offset = index.y;
 
   size_t src_idx = 0;
   for (int i = 0; i < NIDX; ++i) {
-    auto idx_loc = elem_to_loc(
-        ind_idx,
-        &indices.shapes[indices.ndim * i],
-        &indices.strides[indices.ndim * i],
-        indices.ndim);
+    size_t idx_loc;
+    if (IDX_NDIM == 0) {
+      idx_loc = 0;
+    } else if (IDX_NDIM == 1) {
+      idx_loc = ind_idx * indices.strides[indices.ndim * i];
+    } else {
+      idx_loc = elem_to_loc(
+          ind_idx,
+          &indices.shapes[indices.ndim * i],
+          &indices.strides[indices.ndim * i],
+          indices.ndim);
+    }
     auto ax = axes[i];
     auto idx_val = offset_neg_idx(
         indices.buffers[i][idx_loc], src_shape[ax]);
@@ -67,22 +76,49 @@ template <typename T, typename IdxT, int NIDX>
 
   auto src_offset = elem_to_loc(
       ind_offset, slice_sizes, src_strides, src_ndim);
-  out[gid] = src[src_idx + src_offset];
+
+  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
+  out[out_idx] = src[src_offset + src_idx];
 }
 
 #define instantiate_gather4(name, src_type, ind_type, nindex) \
-template [[host_name("gather" name "_" #nindex)]] \
-[[kernel]] void gather( \
+template [[host_name("gather" name "_" #nindex "_0")]] \
+[[kernel]] void gather<src_type, ind_type, nindex, 0>( \
     const device src_type *src [[buffer(0)]], \
-    const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
+    const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
     device src_type *out [[buffer(2)]], \
-    const device int *src_shape [[buffer(3)]], \
-    const device size_t *src_strides [[buffer(4)]], \
-    const device size_t& src_ndim [[buffer(5)]], \
-    const device int *slice_sizes [[buffer(6)]], \
-    const device size_t& slice_size [[buffer(7)]], \
-    const device int* axes [[buffer(8)]], \
-    uint gid [[thread_position_in_grid]]);
+    const constant int *src_shape [[buffer(3)]], \
+    const constant size_t *src_strides [[buffer(4)]], \
+    const constant size_t& src_ndim [[buffer(5)]], \
+    const constant int *slice_sizes [[buffer(6)]], \
+    const constant int* axes [[buffer(7)]], \
+    uint2 index [[thread_position_in_grid]], \
+    uint2 grid_dim [[threads_per_grid]]); \
+template [[host_name("gather" name "_" #nindex "_1")]] \
+[[kernel]] void gather<src_type, ind_type, nindex, 1>( \
+    const device src_type *src [[buffer(0)]], \
+    const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
+    device src_type *out [[buffer(2)]], \
+    const constant int *src_shape [[buffer(3)]], \
+    const constant size_t *src_strides [[buffer(4)]], \
+    const constant size_t& src_ndim [[buffer(5)]], \
+    const constant int *slice_sizes [[buffer(6)]], \
+    const constant int* axes [[buffer(7)]], \
+    uint2 index [[thread_position_in_grid]], \
+    uint2 grid_dim [[threads_per_grid]]); \
+template [[host_name("gather" name "_" #nindex)]] \
+[[kernel]] void gather<src_type, ind_type, nindex, 2>( \
+    const device src_type *src [[buffer(0)]], \
+    const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
+    device src_type *out [[buffer(2)]], \
+    const constant int *src_shape [[buffer(3)]], \
+    const constant size_t *src_strides [[buffer(4)]], \
+    const constant size_t& src_ndim [[buffer(5)]], \
+    const constant int *slice_sizes [[buffer(6)]], \
+    const constant int* axes [[buffer(7)]], \
+    uint2 index [[thread_position_in_grid]], \
+    uint2 grid_dim [[threads_per_grid]]);
 
 
 // Special for case NIDX=0
 instantiate_gather4("bool_", bool, bool, 0)
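The "_0" and "_1" suffixes in the [[host_name(...)]] instantiations above pair with the kname logic added to Gather::eval_gpu: index arrays of dimension 0 or 1 select the corresponding IDX_NDIM specialization, while anything higher falls back to the unsuffixed general kernel. The sketch below is a hypothetical helper, not MLX API, that mirrors that naming scheme; the type-name strings in the asserts are placeholders, not necessarily the exact strings type_to_name produces.

    # Hypothetical helper mirroring the kernel-name construction in eval_gpu:
    # kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx,
    # plus "_" << idx_ndim when idx_ndim <= 1.
    def gather_kernel_name(src_type: str, idx_type: str, nidx: int, idx_ndim: int) -> str:
        name = f"gather{src_type}{idx_type}_{nidx}"
        if idx_ndim <= 1:
            name += f"_{idx_ndim}"  # selects the IDX_NDIM == 0 or 1 specialization
        return name

    assert gather_kernel_name("float32", "uint32", 1, 1) == "gatherfloat32uint32_1_1"
    assert gather_kernel_name("float32", "uint32", 1, 3) == "gatherfloat32uint32_1"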