mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
106 lines
2.3 KiB
C++
106 lines
2.3 KiB
C++
// Copyright © 2025 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <hip/hip_runtime.h>

#include <cstddef>
#include <cstdint>
#include <iterator>
|
|
|
|
namespace mlx::core::rocm {
|
|
|
|
// Device-side iterator that walks a buffer in steps of `stride` elements.
//
// The iterator stores a raw element pointer and an element stride (in units
// of T, not bytes). Advancing by `n` moves the pointer by `n * stride`
// elements; dereferencing reads the element the pointer currently refers to.
//
// Notes:
//  - Comparisons look at the pointer only, so they are meaningful only between
//    iterators that share the same stride and underlying buffer.
//  - The distance between two iterators divides by `stride`, so it must not be
//    requested for a zero-stride iterator.
//  - `stride` is stored unsigned, but every offset is computed in
//    `difference_type` so that negative offsets and negative distances behave
//    as signed quantities instead of wrapping around.
template <typename T>
struct StridedIterator {
  using difference_type = ptrdiff_t;
  using value_type = T;
  using pointer = T*;
  using reference = T&;
  using iterator_category = std::random_access_iterator_tag;

  // Current element.
  T* ptr;
  // Number of T elements between consecutive iterator positions.
  size_t stride;

  // Builds an iterator positioned at `ptr` that steps `stride` elements.
  __device__ StridedIterator(T* ptr, size_t stride)
      : ptr(ptr), stride(stride) {}

  // Returns a copy advanced by `n` positions (`n` may be negative).
  __device__ StridedIterator operator+(difference_type n) const {
    return StridedIterator(ptr + n * signed_stride(), stride);
  }

  // Commutative form `n + it`, expected of random access iterators.
  __device__ friend StridedIterator operator+(
      difference_type n,
      const StridedIterator& it) {
    return it + n;
  }

  // Returns a copy moved back by `n` positions (`n` may be negative).
  __device__ StridedIterator operator-(difference_type n) const {
    return StridedIterator(ptr - n * signed_stride(), stride);
  }

  // Number of positions from `other` to `*this`; negative when `*this`
  // precedes `other`. The division is done in signed arithmetic: dividing by
  // the unsigned stride would convert a negative pointer difference into a
  // huge positive value.
  __device__ difference_type operator-(const StridedIterator& other) const {
    return (ptr - other.ptr) / signed_stride();
  }

  // Advances in place by `n` positions.
  __device__ StridedIterator& operator+=(difference_type n) {
    ptr += n * signed_stride();
    return *this;
  }

  // Moves back in place by `n` positions.
  __device__ StridedIterator& operator-=(difference_type n) {
    ptr -= n * signed_stride();
    return *this;
  }

  // Pre-increment: step forward one position.
  __device__ StridedIterator& operator++() {
    ptr += stride;
    return *this;
  }

  // Post-increment: step forward, returning the previous position.
  __device__ StridedIterator operator++(int) {
    StridedIterator temp = *this;
    ptr += stride;
    return temp;
  }

  // Pre-decrement: step back one position.
  __device__ StridedIterator& operator--() {
    ptr -= stride;
    return *this;
  }

  // Post-decrement: step back, returning the previous position.
  __device__ StridedIterator operator--(int) {
    StridedIterator temp = *this;
    ptr -= stride;
    return temp;
  }

  // Relational operators compare the underlying pointers only.
  __device__ bool operator==(const StridedIterator& other) const {
    return ptr == other.ptr;
  }

  __device__ bool operator!=(const StridedIterator& other) const {
    return ptr != other.ptr;
  }

  __device__ bool operator<(const StridedIterator& other) const {
    return ptr < other.ptr;
  }

  __device__ bool operator>(const StridedIterator& other) const {
    return ptr > other.ptr;
  }

  __device__ bool operator<=(const StridedIterator& other) const {
    return ptr <= other.ptr;
  }

  __device__ bool operator>=(const StridedIterator& other) const {
    return ptr >= other.ptr;
  }

  // Element at the current position.
  __device__ T& operator*() const {
    return *ptr;
  }

  // Element `n` positions away from the current one (`n` may be negative).
  __device__ T& operator[](difference_type n) const {
    return *(ptr + n * signed_stride());
  }

 private:
  // Stride as a signed value so that mixed expressions with `difference_type`
  // are not promoted to unsigned arithmetic.
  __device__ difference_type signed_stride() const {
    return static_cast<difference_type>(stride);
  }
};
|
|
|
|
template <typename T>
|
|
__device__ StridedIterator<T> make_strided_iterator(T* ptr, size_t stride) {
|
|
return StridedIterator<T>(ptr, stride);
|
|
}
|
|
|
|
} // namespace mlx::core::rocm
|