mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Add scatter-axis and gather-axis primitives (#1813)
* add scatter-axis and gather-axis primitives
* add transforms
* address review comment
This commit is contained in:
@@ -35,6 +35,8 @@ make_jit_source(ternary_ops)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
make_jit_source(scatter kernels/indexing.h)
|
||||
make_jit_source(gather kernels/indexing.h)
|
||||
make_jit_source(gather_axis)
|
||||
make_jit_source(scatter_axis)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if(MLX_METAL_JIT)
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/indexing.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -388,4 +389,217 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"gather_axis{0}{1}_{2}",
|
||||
type_to_name(out),
|
||||
type_to_name(idx),
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
kernel_name += src.flags().row_contiguous ? "c" : "nc";
|
||||
kernel_name += idx.flags().row_contiguous ? "c" : "nc";
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::gather_axis();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str = get_type_string(idx.dtype());
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
bool sc = i & 1;
|
||||
bool ic = i & 2;
|
||||
kernel_source += get_template_definition(
|
||||
lib_name + (sc ? "c" : "nc") + (ic ? "c" : "nc"),
|
||||
"gather_axis",
|
||||
out_type_str,
|
||||
idx_type_str,
|
||||
large ? "int64_t" : "int",
|
||||
sc ? "true" : "false",
|
||||
ic ? "true" : "false");
|
||||
}
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
size_post *= idx.shape(i);
|
||||
}
|
||||
|
||||
int idx_ax_size = idx.shape(axis_);
|
||||
auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
|
||||
MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
|
||||
|
||||
// Set all the buffers
|
||||
compute_encoder.set_input_array(src, 0);
|
||||
compute_encoder.set_input_array(idx, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = src.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(src.shape(axis_), 8);
|
||||
compute_encoder.set_bytes(src.strides(axis_), 9);
|
||||
compute_encoder.set_bytes(idx.strides(axis_), 10);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
auto& upd = inputs[2];
|
||||
|
||||
// Copy src into out
|
||||
CopyType copy_type;
|
||||
if (src.data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (src.flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(src, out, copy_type);
|
||||
|
||||
// Empty update
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case ScatterAxis::None:
|
||||
op_name = "none";
|
||||
break;
|
||||
case ScatterAxis::Sum:
|
||||
op_name = "sum";
|
||||
break;
|
||||
}
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"scatter_axis{0}{1}_{2}_{3}",
|
||||
type_to_name(out),
|
||||
type_to_name(idx),
|
||||
op_name,
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
kernel_name += upd.flags().row_contiguous ? "c" : "nc";
|
||||
kernel_name += idx.flags().row_contiguous ? "c" : "nc";
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::reduce_utils();
|
||||
kernel_source += metal::scatter_axis();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str = get_type_string(idx.dtype());
|
||||
std::string op_type;
|
||||
switch (reduce_type_) {
|
||||
case ScatterAxis::None:
|
||||
op_type = "None";
|
||||
break;
|
||||
case ScatterAxis::Sum:
|
||||
op_type = "Sum<" + out_type_str + ">";
|
||||
break;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
bool uc = i & 1;
|
||||
bool ic = i & 2;
|
||||
kernel_source += get_template_definition(
|
||||
lib_name + (uc ? "c" : "nc") + (ic ? "c" : "nc"),
|
||||
"scatter_axis",
|
||||
out_type_str,
|
||||
idx_type_str,
|
||||
large ? "int64_t" : "int",
|
||||
op_type,
|
||||
uc ? "true" : "false",
|
||||
ic ? "true" : "false");
|
||||
}
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
size_post *= idx.shape(i);
|
||||
}
|
||||
|
||||
int idx_ax_size = idx.shape(axis_);
|
||||
auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
|
||||
MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
|
||||
|
||||
// Set all the buffers
|
||||
compute_encoder.set_input_array(upd, 0);
|
||||
compute_encoder.set_input_array(idx, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(out.shape(axis_), 8);
|
||||
compute_encoder.set_bytes(upd.strides(axis_), 9);
|
||||
compute_encoder.set_bytes(idx.strides(axis_), 10);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -18,10 +18,12 @@ const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* fft();
|
||||
const char* gather_axis();
|
||||
const char* hadamard();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* scatter_axis();
|
||||
const char* softmax();
|
||||
const char* sort();
|
||||
const char* reduce();
|
||||
|
44
mlx/backend/metal/kernels/gather_axis.h
Normal file
44
mlx/backend/metal/kernels/gather_axis.h
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Gather from `src` along one axis at the positions in `indices`.
//
// One thread per output element. The host dispatches a grid of
// [size_post, idx_axis_size, size_pre], so:
//   index.x / grid_dim.x -> position / extent after the axis
//   index.y / grid_dim.y -> position / extent along the gathered axis
//   index.z / grid_dim.z -> position / extent before the axis
// `shape`, `src_strides`, and `idx_strides` have the gathered axis removed
// and `ndim` counts the remaining dims; the axis itself is described by the
// scalar `axis_size` / `src_ax_stride` / `idx_ax_stride` arguments.
// SrcC / IdxC enable fast paths when src / indices are row-contiguous.
template <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>
[[kernel]] void gather_axis(
    const device T* src [[buffer(0)]],
    const device IdxT* indices [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant int* shape [[buffer(3)]],
    const constant int64_t* src_strides [[buffer(4)]],
    const constant int64_t* idx_strides [[buffer(5)]],
    const constant size_t& ndim [[buffer(6)]],
    const constant int& axis [[buffer(7)]],
    const constant int& axis_size [[buffer(8)]],
    const constant size_t& src_ax_stride [[buffer(9)]],
    const constant size_t& idx_ax_stride [[buffer(10)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Linear offset contributed by the pre-axis coordinate; elem_idx + index.x
  // is the flat position over the non-axis dims (shape [size_pre, size_post]).
  LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
  // Partial contiguous output location; the axis term is added just before
  // the store below.
  LocT out_idx = elem_idx * grid_dim.y + index.x;

  // Locate the index value: axis contribution plus either the contiguous
  // location (IdxC) or a stride walk over the axis-removed dims.
  LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
  if (IdxC) {
    idx_loc += out_idx;
  } else {
    idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
  }

  auto idx_val = indices[idx_loc];
  // Wrap negative indices (Python-style) for signed index types.
  if (is_signed_v<IdxT>) {
    idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
  }

  // Locate the source element the same way, but with the axis coordinate
  // taken from the gathered index value rather than index.y.
  LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);
  if (SrcC) {
    src_idx += elem_idx * axis_size + index.x;
  } else {
    src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);
  }

  // Complete the contiguous output location with the axis coordinate.
  out_idx += index.y * static_cast<LocT>(grid_dim.x);
  out[out_idx] = src[src_idx];
}
|
52
mlx/backend/metal/kernels/scatter_axis.h
Normal file
52
mlx/backend/metal/kernels/scatter_axis.h
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Scatter `upd` into `out` along one axis at the positions in `indices`,
// combining with `Op` (assign or sum).
//
// One thread per update element. The host dispatches a grid of
// [size_post, idx_axis_size, size_pre], so:
//   index.x / grid_dim.x -> position / extent after the axis
//   index.y / grid_dim.y -> position / extent along the scattered axis
//   index.z / grid_dim.z -> position / extent before the axis
// `shape`, `upd_strides`, and `idx_strides` have the scattered axis removed
// and `ndim` counts the remaining dims; the axis itself is described by the
// scalar `out_axis_size` / `upd_ax_stride` / `idx_ax_stride` arguments.
// UpdC / IdxC enable fast paths when upd / indices are row-contiguous.
// The output is updated atomically so threads scattering to the same
// location do not race.
template <
    typename T,
    typename IdxT,
    typename LocT,
    typename Op,
    bool UpdC,
    bool IdxC>
[[kernel]] void scatter_axis(
    const device T* upd [[buffer(0)]],
    const device IdxT* indices [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* shape [[buffer(3)]],
    const constant int64_t* upd_strides [[buffer(4)]],
    const constant int64_t* idx_strides [[buffer(5)]],
    const constant size_t& ndim [[buffer(6)]],
    const constant int& axis [[buffer(7)]],
    const constant int& out_axis_size [[buffer(8)]],
    const constant size_t& upd_ax_stride [[buffer(9)]],
    const constant size_t& idx_ax_stride [[buffer(10)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  Op op;

  // Linear offset contributed by the pre-axis coordinate; elem_idx + index.x
  // is the flat position over the non-axis dims (shape [size_pre, size_post]).
  LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);

  // Locate the index value: axis contribution plus either the contiguous
  // location (IdxC) or a stride walk over the axis-removed dims.
  LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
  if (IdxC) {
    idx_loc += elem_idx * grid_dim.y + index.x;
  } else {
    idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
  }

  auto idx_val = indices[idx_loc];
  // Wrap negative indices (Python-style) for signed index types.
  if (is_signed_v<IdxT>) {
    idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
  }

  // Locate the update element the same way as the index.
  LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);
  if (UpdC) {
    upd_idx += elem_idx * grid_dim.y + index.x;
  } else {
    upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);
  }

  // Contiguous output location with the axis coordinate taken from the
  // scattered index value (out is row-contiguous: [pre, axis, post]).
  LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +
      idx_val * grid_dim.x + index.x;
  op.atomic_update(out, upd[upd_idx], out_idx);
}
|
Reference in New Issue
Block a user