Files
mlx/mlx/backend/metal/kernels/indexing/gather_front.h
Awni Hannun 111f1e71af Faster contiguous gather for indices in the first axis (#2552)
* faster contiguous gather for indices in the first axis

* work per thread > 1

* angelos suggestion for scales / biases
2025-08-28 21:26:30 -07:00

25 lines
723 B
C++

// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing/indexing.h"
// Gathers rows of `src` selected by `indices` into `out`.
//
// Each thread handles output row `index.y` and copies up to N consecutive
// elements of that row, starting at in-row element `N * index.x`. A row is
// `stride` elements long in both `src` and `out`; the copy stops early at the
// end of the row, so a grid whose x extent overshoots the row is harmless.
//
// Template parameters:
//   T     element type copied from `src` to `out`.
//   IdxT  element type of `indices`.
//   LocT  integer type used for element offsets into `src` and `out`.
//   N     maximum number of elements copied by one thread.
//
// Arguments:
//   src       source buffer, read at `stride * idx + s_idx`.
//   indices   one entry per output row, read at `index.y`.
//   out       destination buffer, written at `stride * index.y + s_idx`.
//   stride    number of elements per row.
//   size      forwarded to offset_neg_idx together with the raw index.
//             NOTE(review): presumably the extent of the indexed axis, used
//             to wrap negative indices -- confirm against indexing.h.
//   index     position of this thread in the grid.
//   grid_dim  threads per grid; not used by the body.
template <typename T, typename IdxT, typename LocT, int N>
[[kernel]] void gather_front(
    const device T* src,
    const device IdxT* indices,
    device T* out,
    const constant int64_t& stride,
    const constant int& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Source row selected for this output row.
  auto idx = offset_neg_idx(indices[index.y], size);

  // Offsets of the first element of the source and destination rows.
  LocT src_idx = static_cast<LocT>(stride) * idx;
  LocT out_idx = static_cast<LocT>(stride) * index.y;

  // Keep the in-row offset in LocT rather than int: `stride` is 64-bit, so
  // with a 64-bit LocT the product N * index.x may exceed INT_MAX, and a
  // narrowed (negative) offset would still pass the `s_idx < stride` check and
  // address the wrong elements. For a 32-bit LocT this is the same
  // computation as before.
  LocT s_idx = static_cast<LocT>(N) * index.x;
  for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {
    out[out_idx + s_idx] = src[src_idx + s_idx];
  }
}