MLX
 
Loading...
Searching...
No Matches
gather_axis.h
Go to the documentation of this file.
1// Copyright © 2025 Apple Inc.
2
3#pragma once
4
5template <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>
6[[kernel]] void gather_axis(
7 const device T* src [[buffer(0)]],
8 const device IdxT* indices [[buffer(1)]],
9 device T* out [[buffer(2)]],
10 const constant int* shape [[buffer(3)]],
11 const constant int64_t* src_strides [[buffer(4)]],
12 const constant int64_t* idx_strides [[buffer(5)]],
13 const constant size_t& ndim [[buffer(6)]],
14 const constant int& axis [[buffer(7)]],
15 const constant int& axis_size [[buffer(8)]],
16 const constant size_t& src_ax_stride [[buffer(9)]],
17 const constant size_t& idx_ax_stride [[buffer(10)]],
18 uint3 index [[thread_position_in_grid]],
19 uint3 grid_dim [[threads_per_grid]]) {
20 LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
21 LocT out_idx = elem_idx * grid_dim.y + index.x;
22
23 LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
24 if (IdxC) {
25 idx_loc += out_idx;
26 } else {
27 idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
28 }
29
30 auto idx_val = indices[idx_loc];
31 if (is_signed_v<IdxT>) {
32 idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
33 }
34
35 LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);
36 if (SrcC) {
37 src_idx += elem_idx * axis_size + index.x;
38 } else {
39 src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);
40 }
41
42 out_idx += index.y * static_cast<LocT>(grid_dim.x);
43 out[out_idx] = src[src_idx];
44}
METAL_FUNC IdxT elem_to_loc(IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)
Definition utils.h:93
void gather_axis(const device T *src, const device IdxT *indices, device T *out, const constant int *shape, const constant int64_t *src_strides, const constant int64_t *idx_strides, const constant size_t &ndim, const constant int &axis, const constant int &axis_size, const constant size_t &src_ax_stride, const constant size_t &idx_ax_stride, uint3 index, uint3 grid_dim)
Definition gather_axis.h:6