mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter
* Use constant address space for shapes and strides
* Split gather and scatter to improve compile times
* Enable the GPU tests
* Update the CI config
* Fix scatter dispatch for scalar indices
* Remove arg encoder utils

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent 1eb04aa23f
commit 1a48713d32
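For orientation before the hunks below: this commit replaces MTL::ArgumentEncoder-packed index metadata with direct bindings. A minimal hedged sketch of the pattern the diff converges on, in metal-cpp host code; the helper name bind_index_params and the slot choices shown are illustrative, not taken from the commit:

// Hypothetical helper (not in the commit) showing the replacement pattern:
// small, per-dispatch parameters go through setBytes, which copies them into
// the constant address space at encode time, and each index array gets its
// own fixed buffer slot instead of an entry in an argument buffer.
#include <vector>
#include <Metal/Metal.hpp>

void bind_index_params(
    MTL::ComputeCommandEncoder* enc,
    const std::vector<int>& idx_shapes,
    const std::vector<size_t>& idx_strides,
    int idx_ndim,
    const std::vector<MTL::Buffer*>& idx_buffers) {
  // Inline constants: no MTL::ArgumentEncoder, no temporary MTL::Buffer,
  // no useResource() hazard flagging, no completed-handler cleanup.
  enc->setBytes(idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
  enc->setBytes(idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
  enc->setBytes(&idx_ndim, sizeof(int), 9);
  // One fixed slot per index array; the diff reserves slots 21, 22, ....
  for (size_t i = 0; i < idx_buffers.size(); ++i) {
    enc->setBuffer(idx_buffers[i], 0, 21 + i);
  }
}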
@@ -94,8 +94,7 @@ jobs:
          command: |
            source env/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
-           # TODO: Reenable when Circle CI can run gpu jobs
-           # DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
+           LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
      # TODO: Reenable when extension api becomes stable
      # - run:
      #     name: Build example extension
@@ -110,8 +109,9 @@ jobs:
            mkdir -p build && cd build && cmake .. && make -j
      - run:
          name: Run CPP tests
-         #command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
-         command: DEVICE=cpu ./build/tests/tests
+         command: |
+           DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
+           DEVICE=cpu ./build/tests/tests

  build_release:
    parameters:
@@ -225,7 +225,9 @@ workflows:
  build_and_test:
    when:
      and:
-       - equal: [ main, << pipeline.git.branch >> ]
+       - matches:
+           pattern: "^(?!pull/)[-\\w]+$"
+           value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
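The `matches` clause above swaps the exact branch-name check for a pattern. A quick hedged check of what that pattern admits, written in plain C++ (std::regex defaults to ECMAScript grammar, which supports the `(?!pull/)` negative lookahead used here); the branch names are illustrative:

#include <cassert>
#include <regex>

int main() {
  std::regex branch_filter("^(?!pull/)[-\\w]+$");
  assert(std::regex_match("main", branch_filter));        // plain branch: ok
  assert(std::regex_match("fix-gather", branch_filter));  // dashes allowed
  assert(!std::regex_match("pull/123", branch_filter));   // PR refs rejected
  assert(!std::regex_match("feature/x", branch_filter));  // '/' not in [-\w]
}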
@@ -215,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
   return eit->second;
 }

-MTL::ArgumentEncoder* Device::argument_encoder(
-    const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
-  // NB array here is already autoreleased but the returned argument
-  // encoder is owned by the caller and must be released/autoreleased
-  NS::Array* arg_desc_arr = NS::Array::array(
-      reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
-  return device_->newArgumentEncoder(arg_desc_arr);
-}
-
 void Device::register_library(
     const std::string& lib_name,
     const std::string& lib_path) {
@@ -51,6 +51,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);

   size_t slice_size = 1;
   for (auto s : slice_sizes_) {
@@ -63,91 +64,50 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto group_dims = get_block_dims(dim0, dim1, 1);
   MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);

-  compute_encoder->setComputePipelineState(kernel);
+  // Collect all idx shapes and strides into one place
+  std::vector<int> idx_shapes;
+  std::vector<size_t> idx_strides;

-  // Make the argument buffer to store the indices for the
-  // `Indices` struct in kernels/indexing.metal
-  std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
-  arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[0]->setIndex(0);
-  arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
-  arg_descs[0]->setArrayLength(nidx);
-
-  // Shapes
-  arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
-  arg_descs[1]->setIndex(nidx + 1);
-
-  // Strides
-  arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
-  arg_descs[2]->setIndex(nidx + 2);
-
-  // Indices ndim
-  arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
-  arg_descs[3]->setIndex(nidx + 3);
-
-  // Get the argument encoder
-  auto arg_enc = d.argument_encoder(arg_descs);
-
-  // Allocate and fill buffers for shapes and strides
-  auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
-  auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
   for (int i = 0; i < nidx; ++i) {
-    std::copy(
-        inputs[i + 1].shape().begin(),
-        inputs[i + 1].shape().end(),
-        static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
-    std::copy(
-        inputs[i + 1].strides().begin(),
-        inputs[i + 1].strides().end(),
-        static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
+    idx_shapes.insert(
+        idx_shapes.end(),
+        inputs[i + 1].shape().begin(),
+        inputs[i + 1].shape().end());
+
+    idx_strides.insert(
+        idx_strides.end(),
+        inputs[i + 1].strides().begin(),
+        inputs[i + 1].strides().end());
   }

-  // Allocate the argument buffer
-  auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
-
-  // Register data with the encoder
-  arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
-  for (int i = 0; i < nidx; ++i) {
-    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
-  }
-  if (idx_ndim > 0) {
-    arg_enc->setBuffer(
-        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
-    compute_encoder->useResource(
-        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
-        MTL::ResourceUsageRead);
-    arg_enc->setBuffer(
-        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
-    compute_encoder->useResource(
-        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
-        MTL::ResourceUsageRead);
-  }
-  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
-
   // Set all the buffers
   set_array_buffer(compute_encoder, src, 0);
-  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  set_array_buffer(compute_encoder, out, 1);

-  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
-  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
-  compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
-  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
-  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7);
+  // Set source info
+  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
+  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
+  compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
+  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
+  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
+
+  // Set index info
+  //
+  // We don't need to check for empty idx_shapes because gather has an
+  // idx_ndim == 0 specialization
+  compute_encoder->setBytes(
+      idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
+  compute_encoder->setBytes(
+      idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
+
+  // Set index buffers
+  for (int i = 1; i < nidx + 1; ++i) {
+    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+  }

   // Launch grid
   compute_encoder->dispatchThreads(grid_dims, group_dims);
-
-  // Cleanup temporaries
-  arg_enc->release();
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
-        allocator::free(arg_buf);
-        allocator::free(idx_shapes_buf);
-        allocator::free(idx_strides_buf);
-      });
 }

 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
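Note how the completed-handler cleanup disappears along with the argument buffer. A minimal hedged sketch of why that is safe, assuming metal-cpp semantics and an illustrative helper name:

// Hedged sketch (not the diff's code): setBytes copies the data into the
// command stream at encode time, so a stack-local vector may be destroyed
// as soon as encoding returns. A MTL::Buffer bound with setBuffer is read
// asynchronously by the GPU and must outlive the command buffer, which is
// what the removed addCompletedHandler/allocator::free dance was for.
#include <vector>
#include <Metal/Metal.hpp>

void encode_shapes(MTL::ComputeCommandEncoder* enc) {
  std::vector<int> idx_shapes = {2, 3};  // illustrative values
  enc->setBytes(idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
}  // idx_shapes freed here; the encoder already owns a copy of the bytes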
@@ -214,77 +174,33 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {

   compute_encoder->setComputePipelineState(kernel);

-  // Make the argument buffer to store the indices for the
-  // `Indices` struct in kernels/indexing.metal
-  std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
-  arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[0]->setIndex(0);
-  arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
-  arg_descs[0]->setArrayLength(nidx);
-
-  // Shapes
-  arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
-  arg_descs[1]->setIndex(nidx + 1);
-
-  // Strides
-  arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
-  arg_descs[2]->setIndex(nidx + 2);
-
-  // Indices ndim
-  arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
-  arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
-  arg_descs[3]->setIndex(nidx + 3);
-
-  // Get the argument encoder
-  auto arg_enc = d.argument_encoder(arg_descs);
-
-  // Allocate and fill buffers for shapes and strides
+  // Collect all idx shapes and strides into one place
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
-  auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
+  std::vector<int> idx_shapes;
+  std::vector<size_t> idx_strides;

   for (int i = 0; i < nidx; ++i) {
-    std::copy(
-        inputs[i + 1].shape().begin(),
-        inputs[i + 1].shape().end(),
-        static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
-    std::copy(
-        inputs[i + 1].strides().begin(),
-        inputs[i + 1].strides().end(),
-        static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
+    idx_shapes.insert(
+        idx_shapes.end(),
+        inputs[i + 1].shape().begin(),
+        inputs[i + 1].shape().end());
+
+    idx_strides.insert(
+        idx_strides.end(),
+        inputs[i + 1].strides().begin(),
+        inputs[i + 1].strides().end());
   }

-  // Allocate the argument buffer
-  auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
+  // Set all the buffers
+  set_array_buffer(compute_encoder, upd, 1);
+  set_array_buffer(compute_encoder, out, 2);

-  // Register data with the encoder
-  arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
-  for (int i = 0; i < nidx; ++i) {
-    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
-  }
-  if (idx_ndim > 0) {
-    arg_enc->setBuffer(
-        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
-    compute_encoder->useResource(
-        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
-        MTL::ResourceUsageRead);
-    arg_enc->setBuffer(
-        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
-    compute_encoder->useResource(
-        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
-        MTL::ResourceUsageRead);
-  }
-  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
-
-  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
-
+  // Set update info
   size_t upd_ndim = upd.ndim();
   size_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
-  set_array_buffer(compute_encoder, upd, 1);
-  set_array_buffer(compute_encoder, out, 2);
   if (upd_ndim == 0) {
     // Need placeholders so Metal doesn't complain
     int shape_ = 0;
@@ -299,6 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
   compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

+  // Set output info
   size_t out_ndim = out.ndim();
   if (out_ndim == 0) {
     // Need placeholders so Metal doesn't complain
@@ -314,18 +231,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
   compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);

+  // Set index info
+  if (idx_ndim == 0) {
+    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
+    // error in the metal API.
+    idx_shapes.push_back(0);
+    idx_strides.push_back(0);
+  }
+  compute_encoder->setBytes(
+      idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
+  compute_encoder->setBytes(
+      idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
+
+  // Set index buffers
+  for (int i = 1; i < nidx + 1; ++i) {
+    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+  }
+
   // Launch grid
   MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
   MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
   compute_encoder->dispatchThreads(grid_dims, group_dims);
-
-  // Cleanup temporaries
-  arg_enc->release();
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
-        allocator::free(arg_buf);
-        allocator::free(idx_shapes_buf);
-        allocator::free(idx_strides_buf);
-      });
 }

 } // namespace mlx::core
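The dispatch above launches a 2-D grid over update elements and index positions. A hedged worked example of the geometry, assuming (as the kernel body does) that gid.x walks the upd_size elements of one update slice and gid.y walks the index positions; the sizes are illustrative:

// Hedged worked example of the scatter dispatch geometry.
#include <cassert>
#include <cstddef>

int main() {
  size_t n_indices = 3;                    // scatter 3 update rows
  size_t upd_size = 8;                     // each row has 8 elements
  size_t nthreads = n_indices * upd_size;  // 24 threads total
  size_t grid_x = upd_size;                // MTL::Size(upd_size, ...)
  size_t grid_y = nthreads / upd_size;     // ... nthreads / upd_size
  assert(grid_x == 8 && grid_y == 3);      // one thread per update element
}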
@@ -6,6 +6,7 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
+  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
 )
@@ -26,7 +27,8 @@ set(
   "softmax"
   "sort"
   "unary"
-  "indexing"
+  "gather"
+  "scatter"
 )

 function(build_kernel_base TARGET SRCFILE DEPS)
mlx/backend/metal/kernels/gather.metal (new file, 187 lines)
@@ -0,0 +1,187 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
    const device T *src [[buffer(0)]],
    device T *out [[buffer(1)]],
    const constant int *src_shape [[buffer(2)]],
    const constant size_t *src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int *slice_sizes [[buffer(5)]],
    const constant int *axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {

  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(
      ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}

#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
    const device T *src [[buffer(0)]], \
    device T *out [[buffer(1)]], \
    const constant int *src_shape [[buffer(2)]], \
    const constant size_t *src_strides [[buffer(3)]], \
    const constant size_t& src_ndim [[buffer(4)]], \
    const constant int *slice_sizes [[buffer(5)]], \
    const constant int *axes [[buffer(6)]], \
    const constant int *idx_shapes [[buffer(7)]], \
    const constant size_t *idx_strides [[buffer(8)]], \
    const constant int& idx_ndim [[buffer(9)]], \
    IDX_ARG(IdxT) \
    uint2 index [[thread_position_in_grid]], \
    uint2 grid_dim [[threads_per_grid]]) { \
  \
  Indices<IdxT, NIDX> idxs{ \
      {{IDX_ARR()}}, \
      idx_shapes, \
      idx_strides, \
      idx_ndim}; \
  \
  return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
      src, \
      out, \
      src_shape, \
      src_strides, \
      src_ndim, \
      slice_sizes, \
      axes, \
      idxs, \
      index, \
      grid_dim); \
}

#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)

make_gather(0)
make_gather(1)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)

/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
    const device src_t *src [[buffer(0)]], \
    device src_t *out [[buffer(1)]], \
    const constant int *src_shape [[buffer(2)]], \
    const constant size_t *src_strides [[buffer(3)]], \
    const constant size_t& src_ndim [[buffer(4)]], \
    const constant int *slice_sizes [[buffer(5)]], \
    const constant int *axes [[buffer(6)]], \
    const constant int *idx_shapes [[buffer(7)]], \
    const constant size_t *idx_strides [[buffer(8)]], \
    const constant int& idx_ndim [[buffer(9)]], \
    IDX_ARG(idx_t) \
    uint2 index [[thread_position_in_grid]], \
    uint2 grid_dim [[threads_per_grid]]);

#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
  instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)

#define instantiate_gather4(name, src_t, idx_t, nidx) \
  instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
  instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
  instantiate_gather5(name, src_t, idx_t, nidx, 2, )

// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)

#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10)

#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
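Both code paths in gather_impl lean on elem_to_loc to turn a flat element index into a strided memory offset. A hedged host-side rendition, assumed to mirror the helper in mlx/backend/metal/kernels/utils.h, for readers checking the index arithmetic:

// Hedged host-side rendition of elem_to_loc (assumed from utils.h): converts
// a flat row-major element index into a memory offset for a strided array.
#include <cstddef>

size_t elem_to_loc(
    size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  // Peel dimensions from the innermost (fastest-varying) axis outward.
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}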
mlx/backend/metal/kernels/indexing.h (new file, 54 lines)
@@ -0,0 +1,54 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_stdlib>

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////

template <typename IdxT, int NIDX>
struct Indices {
  const array<const device IdxT*, NIDX> buffers;
  const constant int* shapes;
  const constant size_t* strides;
  const int ndim;
};

template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
  if (is_unsigned_v<IdxT>) {
    return idx;
  } else {
    return (idx < 0) ? idx + size : idx;
  }
}

#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],

#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)

#define IDX_ARR_N(n) idx##n,

#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
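The IDX_ARG_n / IDX_ARR_n recursion generates n index-buffer parameters at the fixed slots 21..20+n and the matching argument list for the Indices initializer. A small hedged verifier of the expansion, with the device/buffer attributes dropped so it runs under a host C++ compiler; the macros are copied in reduced form for illustration:

// Reproduces the token-pasting recursion in plain C++ so the expansion can
// be inspected directly (or via -E). Attributes from the Metal version are
// omitted; otherwise the shape of the expansion is the same.
#include <cstdio>

#define IDX_ARG_N(idx_t, n) const idx_t *idx##n,
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)

#define STR2(x) #x
#define STR(x) STR2(x)

int main() {
  // Prints approximately: const unsigned *idx21, const unsigned *idx22,
  std::puts(STR(IDX_ARG_2(unsigned)));
}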
mlx/backend/metal/kernels/indexing.metal (deleted file, 290 lines)
@@ -1,290 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>
#include <metal_texture>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename IdxT, int NIDX>
struct Indices {
  const array<device IdxT*, NIDX> buffers [[id(0)]];
  device int* shapes [[id(NIDX + 1)]];
  device size_t* strides [[id(NIDX + 2)]];
  const int ndim [[id(NIDX + 3)]];
};

template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

template <>
inline size_t offset_neg_idx(bool idx, size_t) {
  return idx;
}

template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
  return idx;
}

// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time
// special case for 0 and 1. Anything >= 2 uses the general case
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
[[kernel]] void gather(
    const device T *src [[buffer(0)]],
    const constant Indices<IdxT, NIDX>& indices [[buffer(1)]],
    device T *out [[buffer(2)]],
    const constant int *src_shape [[buffer(3)]],
    const constant size_t *src_strides [[buffer(4)]],
    const constant size_t& src_ndim [[buffer(5)]],
    const constant int *slice_sizes [[buffer(6)]],
    const constant int *axes [[buffer(7)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {

  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(
      ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}

#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex "_0")]] \
[[kernel]] void gather<src_type, ind_type, nindex, 0>( \
    const device src_type *src [[buffer(0)]], \
    const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
    device src_type *out [[buffer(2)]], \
    const constant int *src_shape [[buffer(3)]], \
    const constant size_t *src_strides [[buffer(4)]], \
    const constant size_t& src_ndim [[buffer(5)]], \
    const constant int *slice_sizes [[buffer(6)]], \
    const constant int* axes [[buffer(7)]], \
    uint2 index [[thread_position_in_grid]], \
    uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("gather" name "_" #nindex "_1")]] \
[[kernel]] void gather<src_type, ind_type, nindex, 1>( \
    const device src_type *src [[buffer(0)]], \
    const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
    device src_type *out [[buffer(2)]], \
    const constant int *src_shape [[buffer(3)]], \
    const constant size_t *src_strides [[buffer(4)]], \
    const constant size_t& src_ndim [[buffer(5)]], \
    const constant int *slice_sizes [[buffer(6)]], \
    const constant int* axes [[buffer(7)]], \
    uint2 index [[thread_position_in_grid]], \
    uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather<src_type, ind_type, nindex, 2>( \
    const device src_type *src [[buffer(0)]], \
    const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \
    device src_type *out [[buffer(2)]], \
    const constant int *src_shape [[buffer(3)]], \
    const constant size_t *src_strides [[buffer(4)]], \
    const constant size_t& src_ndim [[buffer(5)]], \
    const constant int *slice_sizes [[buffer(6)]], \
    const constant int* axes [[buffer(7)]], \
    uint2 index [[thread_position_in_grid]], \
    uint2 grid_dim [[threads_per_grid]]);

// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)

#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10)

#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)

/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
    const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
    const device T *updates [[buffer(1)]],
    device mlx_atomic<T> *out [[buffer(2)]],
    const device int *upd_shape [[buffer(3)]],
    const device size_t *upd_strides [[buffer(4)]],
    const device size_t& upd_ndim [[buffer(5)]],
    const device size_t& upd_size [[buffer(6)]],
    const device int *out_shape [[buffer(7)]],
    const device size_t *out_strides [[buffer(8)]],
    const device size_t& out_ndim [[buffer(9)]],
    const device int* axes [[buffer(10)]],
    uint2 gid [[thread_position_in_grid]]) {

  Op op;
  auto ind_idx = gid.y;
  auto ind_offset = gid.x;

  size_t out_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }

  auto out_offset = elem_to_loc(
      ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
  auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}

#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
template [[host_name("scatter" name "_" #nindex)]] \
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
    const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
    const device type *updates [[buffer(1)]], \
    device mlx_atomic<type> *out [[buffer(2)]], \
    const device int *upd_shape [[buffer(3)]], \
    const device size_t *upd_strides [[buffer(4)]], \
    const device size_t& upd_ndim [[buffer(5)]], \
    const device size_t& upd_size [[buffer(6)]], \
    const device int *out_shape [[buffer(7)]], \
    const device size_t *out_strides [[buffer(8)]], \
    const device size_t& out_ndim [[buffer(9)]], \
    const device int* axes [[buffer(10)]], \
    uint2 gid [[thread_position_in_grid]]);

// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
  instantiate_scatter4(#name "none", type, bool, None, 0) \
  instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
  instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
  instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)

#define instantiate_scatter3(name, type, ind_type, op_type) \
  instantiate_scatter4(name, type, ind_type, op_type, 1) \
  instantiate_scatter4(name, type, ind_type, op_type, 2) \
  instantiate_scatter4(name, type, ind_type, op_type, 3) \
  instantiate_scatter4(name, type, ind_type, op_type, 4) \
  instantiate_scatter4(name, type, ind_type, op_type, 5) \
  instantiate_scatter4(name, type, ind_type, op_type, 6) \
  instantiate_scatter4(name, type, ind_type, op_type, 7) \
  instantiate_scatter4(name, type, ind_type, op_type, 8) \
  instantiate_scatter4(name, type, ind_type, op_type, 9) \
  instantiate_scatter4(name, type, ind_type, op_type, 10)

#define instantiate_scatter2(name, type, ind_type) \
  instantiate_scatter3(name "_none", type, ind_type, None) \
  instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
  instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
  instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
  instantiate_scatter3(name "_min", type, ind_type, Min<type>)

#define instantiate_scatter(name, type) \
  instantiate_scatter2(#name "bool_", type, bool) \
  instantiate_scatter2(#name "uint8", type, uint8_t) \
  instantiate_scatter2(#name "uint16", type, uint16_t) \
  instantiate_scatter2(#name "uint32", type, uint32_t) \
  instantiate_scatter2(#name "uint64", type, uint64_t) \
  instantiate_scatter2(#name "int8", type, int8_t) \
  instantiate_scatter2(#name "int16", type, int16_t) \
  instantiate_scatter2(#name "int32", type, int32_t) \
  instantiate_scatter2(#name "int64", type, int64_t)

// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)

instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)
mlx/backend/metal/kernels/scatter.metal (new file, 194 lines)
@@ -0,0 +1,194 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
    const device T *updates [[buffer(1)]],
    device mlx_atomic<T> *out [[buffer(2)]],
    const constant int *upd_shape [[buffer(3)]],
    const constant size_t *upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int *out_shape [[buffer(7)]],
    const constant size_t *out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {

  Op op;
  auto ind_idx = gid.y;
  auto ind_offset = gid.x;

  size_t out_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }

  auto out_offset = elem_to_loc(
      ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
  auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}

#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
    const device T *updates [[buffer(1)]], \
    device mlx_atomic<T> *out [[buffer(2)]], \
    const constant int *upd_shape [[buffer(3)]], \
    const constant size_t *upd_strides [[buffer(4)]], \
    const constant size_t& upd_ndim [[buffer(5)]], \
    const constant size_t& upd_size [[buffer(6)]], \
    const constant int *out_shape [[buffer(7)]], \
    const constant size_t *out_strides [[buffer(8)]], \
    const constant size_t& out_ndim [[buffer(9)]], \
    const constant int* axes [[buffer(10)]], \
    const constant int *idx_shapes [[buffer(11)]], \
    const constant size_t *idx_strides [[buffer(12)]], \
    const constant int& idx_ndim [[buffer(13)]], \
    IDX_ARG(IdxT) \
    uint2 gid [[thread_position_in_grid]]) { \
  \
  Indices<IdxT, NIDX> idxs{ \
      {{IDX_ARR()}}, \
      idx_shapes, \
      idx_strides, \
      idx_ndim}; \
  \
  return scatter_impl<T, IdxT, Op, NIDX>( \
      updates, \
      out, \
      upd_shape, \
      upd_strides, \
      upd_ndim, \
      upd_size, \
      out_shape, \
      out_strides, \
      out_ndim, \
      axes, \
      idxs, \
      gid); \
}

#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)

make_scatter(0)
make_scatter(1)
make_scatter(2)
make_scatter(3)
make_scatter(4)
make_scatter(5)
make_scatter(6)
make_scatter(7)
make_scatter(8)
make_scatter(9)
make_scatter(10)

/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] \
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
    const device src_t *updates [[buffer(1)]], \
    device mlx_atomic<src_t> *out [[buffer(2)]], \
    const constant int *upd_shape [[buffer(3)]], \
    const constant size_t *upd_strides [[buffer(4)]], \
    const constant size_t& upd_ndim [[buffer(5)]], \
    const constant size_t& upd_size [[buffer(6)]], \
    const constant int *out_shape [[buffer(7)]], \
    const constant size_t *out_strides [[buffer(8)]], \
    const constant size_t& out_ndim [[buffer(9)]], \
    const constant int* axes [[buffer(10)]], \
    const constant int *idx_shapes [[buffer(11)]], \
    const constant size_t *idx_strides [[buffer(12)]], \
    const constant int& idx_ndim [[buffer(13)]], \
    IDX_ARG(idx_t) \
    uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
  instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)

// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
  instantiate_scatter4(#name "none", type, bool, None, 0) \
  instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
  instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
  instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)

#define instantiate_scatter3(name, type, ind_type, op_type) \
  instantiate_scatter4(name, type, ind_type, op_type, 1) \
  instantiate_scatter4(name, type, ind_type, op_type, 2) \
  instantiate_scatter4(name, type, ind_type, op_type, 3) \
  instantiate_scatter4(name, type, ind_type, op_type, 4) \
  instantiate_scatter4(name, type, ind_type, op_type, 5) \
  instantiate_scatter4(name, type, ind_type, op_type, 6) \
  instantiate_scatter4(name, type, ind_type, op_type, 7) \
  instantiate_scatter4(name, type, ind_type, op_type, 8) \
  instantiate_scatter4(name, type, ind_type, op_type, 9) \
  instantiate_scatter4(name, type, ind_type, op_type, 10)

#define instantiate_scatter2(name, type, ind_type) \
  instantiate_scatter3(name "_none", type, ind_type, None) \
  instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
  instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
  instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
  instantiate_scatter3(name "_min", type, ind_type, Min<type>)

#define instantiate_scatter(name, type) \
  instantiate_scatter2(#name "bool_", type, bool) \
  instantiate_scatter2(#name "uint8", type, uint8_t) \
  instantiate_scatter2(#name "uint16", type, uint16_t) \
  instantiate_scatter2(#name "uint32", type, uint32_t) \
  instantiate_scatter2(#name "uint64", type, uint64_t) \
  instantiate_scatter2(#name "int8", type, int8_t) \
  instantiate_scatter2(#name "int16", type, int16_t) \
  instantiate_scatter2(#name "int32", type, int32_t) \
  instantiate_scatter2(#name "int64", type, int64_t)

// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)

instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)
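For readers checking the scatter arithmetic: a hedged serial reference for what scatter_impl computes in the common contiguous case (single index array, Sum op, row-wise updates on a 2-D output). This simplifies away the general stride handling and the atomics; the function name and shapes are illustrative:

// Serial reference: each index picks a row of a (nrows, ncols) output to
// accumulate one ncols-wide update row into. On the GPU the inner loop is
// one thread (gid.x) per element and the += is op.atomic_update.
#include <vector>

void scatter_sum_rows(
    const std::vector<int>& indices,    // one row index per update row
    const std::vector<float>& updates,  // indices.size() x ncols values
    std::size_t ncols,
    std::vector<float>& out) {          // nrows x ncols, flattened
  const long nrows = static_cast<long>(out.size() / ncols);
  for (std::size_t i = 0; i < indices.size(); ++i) {  // gid.y: index position
    long idx = indices[i];
    if (idx < 0) idx += nrows;                        // offset_neg_idx
    for (std::size_t j = 0; j < ncols; ++j) {         // gid.x: update element
      out[idx * ncols + j] += updates[i * ncols + j]; // Sum op
    }
  }
}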
@@ -9,20 +9,6 @@ namespace mlx::core {

 namespace {

-void set_array_buffer(
-    MTL::ComputeCommandEncoder* compute_encoder,
-    MTL::ArgumentEncoder* enc,
-    const array& a,
-    int idx) {
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  enc->setBuffer(a_buf, offset, idx);
-  // MTL::Resource usage through argument buffer needs to be explicitly
-  // flagged to enable hazard tracking
-  compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
-}
-
 void set_array_buffer(
     MTL::ComputeCommandEncoder* enc,
     const array& a,