mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
24 lines
487 B
C++
24 lines
487 B
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <metal_stdlib>
|
|
|
|
template <typename IdxT, int NIDX>
|
|
struct Indices {
|
|
const array<const device IdxT*, NIDX> buffers;
|
|
const constant int* shapes;
|
|
const constant int64_t* strides;
|
|
const constant bool* row_contiguous;
|
|
const int ndim;
|
|
};
|
|
|
|
template <typename IdxT>
|
|
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
|
|
if (is_unsigned_v<IdxT>) {
|
|
return idx;
|
|
} else {
|
|
return (idx < 0) ? idx + size : idx;
|
|
}
|
|
}
|