2024-05-23 03:57:13 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "mlx/backend/metal/kernels/indexing.h"
|
|
|
|
|
2024-11-19 11:52:00 +08:00
|
|
|
// Gather kernel body: each thread copies one element from `src` to `out`.
//
// Template parameters:
//   T        - element type of `src` and `out`.
//   IdxT     - element type of the index buffers held by `indices`.
//   NIDX     - number of index arrays (one per gathered axis in `axes`).
//   IDX_NDIM - compile-time selector for how the index arrays are addressed:
//              0 reads element 0 of every index buffer, 1 addresses them with
//              index.x only, anything else uses index.x and index.y.
//   LocT     - integer type used for all offset arithmetic.
//
// Grid usage, as read from the code below:
//   index.x - multiplied by the first stride of each index array.
//   index.y - position within the remaining (ndim - 1) index dimensions.
//   index.z - element position within the slice described by `slice_sizes`.
//
// `indices.shapes` / `indices.strides` are read as `indices.ndim` entries per
// index array, i.e. the entries of array k start at `indices.ndim * k`.
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
METAL_FUNC void gather_impl(
    const device T* src [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Offset into `src` contributed by the looked-up index values.
  LocT gathered_offset = 0;

  for (int k = 0; k < NIDX; ++k) {
    // Start of the k-th index array's entries in the packed shape/stride
    // tables.
    const auto meta_base = indices.ndim * k;

    // Location of this thread's entry inside the k-th index buffer.
    LocT idx_loc = 0;
    if (IDX_NDIM != 0) {
      // The leading index dimension is addressed directly by index.x.
      idx_loc = index.x * static_cast<LocT>(indices.strides[meta_base]);

      if (IDX_NDIM != 1) {
        // The remaining dimensions are addressed through index.y. When the
        // index array is flagged row contiguous, index.y is used as the
        // offset as-is; otherwise it is mapped through the shape/stride
        // tables (presumably a flat-index -> strided-offset conversion;
        // see elem_to_loc).
        if (indices.row_contiguous[k]) {
          idx_loc += index.y;
        } else {
          idx_loc += elem_to_loc<size_t, LocT>(
              index.y,
              &indices.shapes[meta_base + 1],
              &indices.strides[meta_base + 1],
              indices.ndim - 1);
        }
      }
    }

    // Source axis this index array selects along.
    const int axis = axes[k];

    // NOTE(review): offset_neg_idx presumably wraps negative indices using
    // the axis extent -- confirm against its definition in indexing.h.
    auto axis_index =
        offset_neg_idx(indices.buffers[k][idx_loc], src_shape[axis]);

    gathered_offset +=
        static_cast<LocT>(axis_index) * static_cast<LocT>(src_strides[axis]);
  }

  // Offset of this thread's element within the slice, expressed in `src`
  // strides.
  auto slice_offset =
      elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);

  // Output location: index.z within the slice, with slices ordered by
  // index.x (and index.y when the indices have two or more dimensions).
  LocT out_idx = index.z;
  if (IDX_NDIM >= 2) {
    out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
  } else if (IDX_NDIM == 1) {
    out_idx += static_cast<LocT>(grid_dim.z) * index.x;
  }

  out[out_idx] = src[slice_offset + gathered_offset];
}
|