mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
44
mlx/backend/metal/kernels/gather_axis.h
Normal file
44
mlx/backend/metal/kernels/gather_axis.h
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>
|
||||
[[kernel]] void gather_axis(
|
||||
const device T* src [[buffer(0)]],
|
||||
const device IdxT* indices [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
const constant int* shape [[buffer(3)]],
|
||||
const constant int64_t* src_strides [[buffer(4)]],
|
||||
const constant int64_t* idx_strides [[buffer(5)]],
|
||||
const constant size_t& ndim [[buffer(6)]],
|
||||
const constant int& axis [[buffer(7)]],
|
||||
const constant int& axis_size [[buffer(8)]],
|
||||
const constant size_t& src_ax_stride [[buffer(9)]],
|
||||
const constant size_t& idx_ax_stride [[buffer(10)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
|
||||
LocT out_idx = elem_idx * grid_dim.y + index.x;
|
||||
|
||||
LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
|
||||
if (IdxC) {
|
||||
idx_loc += out_idx;
|
||||
} else {
|
||||
idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
|
||||
}
|
||||
|
||||
auto idx_val = indices[idx_loc];
|
||||
if (is_signed_v<IdxT>) {
|
||||
idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
|
||||
}
|
||||
|
||||
LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);
|
||||
if (SrcC) {
|
||||
src_idx += elem_idx * axis_size + index.x;
|
||||
} else {
|
||||
src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);
|
||||
}
|
||||
|
||||
out_idx += index.y * static_cast<LocT>(grid_dim.x);
|
||||
out[out_idx] = src[src_idx];
|
||||
}
|
||||
52
mlx/backend/metal/kernels/scatter_axis.h
Normal file
52
mlx/backend/metal/kernels/scatter_axis.h
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename LocT,
|
||||
typename Op,
|
||||
bool UpdC,
|
||||
bool IdxC>
|
||||
[[kernel]] void scatter_axis(
|
||||
const device T* upd [[buffer(0)]],
|
||||
const device IdxT* indices [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* shape [[buffer(3)]],
|
||||
const constant int64_t* upd_strides [[buffer(4)]],
|
||||
const constant int64_t* idx_strides [[buffer(5)]],
|
||||
const constant size_t& ndim [[buffer(6)]],
|
||||
const constant int& axis [[buffer(7)]],
|
||||
const constant int& out_axis_size [[buffer(8)]],
|
||||
const constant size_t& upd_ax_stride [[buffer(9)]],
|
||||
const constant size_t& idx_ax_stride [[buffer(10)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
Op op;
|
||||
|
||||
LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
|
||||
|
||||
LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
|
||||
if (IdxC) {
|
||||
idx_loc += elem_idx * grid_dim.y + index.x;
|
||||
} else {
|
||||
idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
|
||||
}
|
||||
|
||||
auto idx_val = indices[idx_loc];
|
||||
if (is_signed_v<IdxT>) {
|
||||
idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
|
||||
}
|
||||
|
||||
LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);
|
||||
if (UpdC) {
|
||||
upd_idx += elem_idx * grid_dim.y + index.x;
|
||||
} else {
|
||||
upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);
|
||||
}
|
||||
|
||||
LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +
|
||||
idx_val * grid_dim.x + index.x;
|
||||
op.atomic_update(out, upd[upd_idx], out_idx);
|
||||
}
|
||||
Reference in New Issue
Block a user