mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00

* try cpp 20 for compile * unary, binary, ternary in jit * nits * fix gather/scatter * fix rebase * reorg compile * add ternary to compile * jit copy * jit compile flag * fix build * use linked function for ternary * some nits * docs + circle min size build * docs + circle min size build * fix extension * fix no cpu build * improve includes
23 lines
450 B
C++
23 lines
450 B
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <metal_stdlib>
|
|
|
|
// Bundle of index arrays plus the metadata needed to address them.
// Holds NIDX read-only pointers to index buffers (in the `device` address
// space, elements of type IdxT), pointers to shape and stride tables in the
// `constant` address space, and a dimension count.
// NOTE(review): how `shapes`/`strides` are laid out (e.g. NIDX consecutive
// runs of `ndim` entries each) is decided by the host-side setup and the
// kernels that consume this struct, not here -- confirm there before relying
// on a particular indexing scheme.
template <typename IdxT, int NIDX>
|
|
struct Indices {
|
|
// One pointer per index array; the index values are read-only through it.
const array<const device IdxT*, NIDX> buffers;
|
|
// Shape table (int entries) in constant memory.
const constant int* shapes;
|
|
// Stride table (size_t entries) in constant memory.
const constant size_t* strides;
|
|
// Dimension count; presumably the rank of each index array -- TODO confirm
// against callers.
const int ndim;
|
|
};
|
|
|
|
// Maps a possibly-negative index onto a non-negative offset along an axis
// of length `size`.
//
//   idx  - the raw index value read from an index buffer.
//   size - the extent of the axis being indexed.
//
// Returns idx unchanged for unsigned IdxT. For signed IdxT, a negative idx is
// shifted up by `size` (so -1 becomes size - 1) and a non-negative idx passes
// through. No bounds check is done: idx >= size, or idx < -size (which wraps
// when converted to size_t), yields an out-of-range offset.
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
  // A plain (non-constexpr) `if`: both paths are compiled for every IdxT,
  // but unsigned indices return here and never reach the sign test below.
  if (is_unsigned_v<IdxT>) {
    return idx;
  }

  // Signed index: fold negative values back into [0, size).
  if (idx < 0) {
    return idx + size;
  }

  return idx;
}
|