// Copyright © 2025 Apple Inc.

#include <cuda/std/cstdint>
#include <cuda/std/tuple>
#include <cuda/std/type_traits>

namespace mlx::core::cu {

// Convert an absolute (flattened) index to positions in a 3d grid, assuming
// the index was calculated with:
//
//   index = x * dim1 * dim2 + y * dim2 + z
//
// Params:
//   index: the flattened index to decompose.
//   dim1:  extent of the middle (y) axis.
//   dim2:  extent of the innermost (z) axis.
//
// Returns the tuple (x, y, z). The extent of the outermost axis is not needed
// to invert the formula above, so it is not a parameter.
//
// Both dim1 and dim2 must be non-zero: they are used as divisors and no check
// is performed here. The product dim1 * dim2 is evaluated in T.
template <typename T>
inline __host__ __device__ cuda::std::tuple<T, T, T>
index_to_dims(T index, T dim1, T dim2) {
  // Number of elements spanned by one step along x; computed once and reused
  // for both the quotient and the remainder.
  const T plane = dim1 * dim2;
  T x = index / plane;
  T y = (index % plane) / dim2;
  T z = index % dim2;
  return cuda::std::make_tuple(x, y, z);
}

// Get the absolute index from a possibly negative index.
//
// Params:
//   idx:  the index to normalize; a negative value counts from the end.
//   size: the extent of the axis being indexed.
//
// Returns:
//   - For unsigned IdxT: idx unchanged, with type IdxT (an unsigned value can
//     never be negative, so there is nothing to wrap).
//   - For signed IdxT: idx + size when idx < 0, otherwise idx, converted to
//     int32_t.
//
// No bounds checking is done: a result outside [0, size) is returned as is.
template <typename IdxT>
inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
  if constexpr (cuda::std::is_unsigned_v<IdxT>) {
    return idx;
  } else {
    // NOTE(review): the cast target is assumed to be int32_t (the type of
    // `size`) -- confirm against the callers' expectations.
    return static_cast<int32_t>(idx < 0 ? idx + size : idx);
  }
}

} // namespace mlx::core::cu