mlx/mlx/backend/metal/kernels/indexing.h
Awni Hannun 226748b3e7
JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00

23 lines
450 B
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_stdlib>
// Bundle of index inputs, parameterized on the index element type `IdxT` and
// the number of index buffers `NIDX`. Every member is declared const.
// NOTE(review): the meaning of each field below is inferred from its name and
// type only; confirm against the kernels that consume this struct.
template <typename IdxT, int NIDX>
struct Indices {
// NIDX pointers to read-only `IdxT` data in the `device` address space.
const array<const device IdxT*, NIDX> buffers;
// Pointer to int data in the `constant` address space
// (presumably the shape of the index arrays; TODO confirm).
const constant int* shapes;
// Pointer to size_t data in the `constant` address space
// (presumably the strides of the index arrays; TODO confirm).
const constant size_t* strides;
// Presumably the number of dimensions described by `shapes`/`strides`
// (TODO confirm).
const int ndim;
};
// Converts a possibly negative index into a non-negative offset.
//
// A negative `idx` of a signed index type is shifted up by `size`
// (so -1 becomes size - 1). Any other value, including every value of an
// unsigned index type, is returned unchanged after conversion to size_t.
// No range check is performed on the result.
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
  // Only signed index types can hold a negative value that needs wrapping.
  if (!is_unsigned_v<IdxT> && idx < 0) {
    // `idx` is converted to size_t before the add; the modular sum yields
    // size - |idx|.
    return idx + size;
  }
  return idx;
}