Files
mlx/mlx/backend/metal/kernels/indexing.h
Awni Hannun 40c62c1321 Use int64 stride everywhere (#1671)
* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
2024-12-09 11:09:02 -08:00

24 lines
487 B
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_stdlib>
// Groups NIDX index buffers together with the shape/stride/contiguity
// metadata used to read them.
// NOTE(review): how shapes/strides/row_contiguous are packed (e.g. NIDX
// entries of ndim values each) is decided by the code that builds this
// struct, which is not visible here — confirm at the construction site.
template <typename IdxT, int NIDX>
struct Indices {
// One read-only pointer per index buffer, in the device address space.
const array<const device IdxT*, NIDX> buffers;
// Shape values (int) in the constant address space.
const constant int* shapes;
// Stride values in the constant address space; stored as int64_t so large
// arrays do not overflow when computing element offsets.
const constant int64_t* strides;
// Per-buffer contiguity flags in the constant address space (presumably one
// flag per index buffer — TODO confirm).
const constant bool* row_contiguous;
// Number of dimensions described by the shape/stride data.
const int ndim;
};
// Resolves a possibly-negative index against an axis of length `size`.
//
// For signed index types, a negative value is shifted up by `size`
// (so -1 maps to size - 1). Non-negative values, and every value of an
// unsigned index type, are returned unchanged. The result is widened to
// size_t. No bounds checking happens here: an index below -size still
// yields a negative sum, which wraps when converted to size_t.
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
  // Unsigned index types can never be negative. The short-circuit keeps them
  // on the pass-through path without evaluating the comparison's outcome.
  if (!is_unsigned_v<IdxT> && idx < 0) {
    return idx + size;
  }
  return idx;
}