mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix gemv regression (#2445)
This commit is contained in:
@@ -43,10 +43,18 @@ struct alignas(sizeof(T) * N) AlignedVector {
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ bool is_aligned(T* x) {
|
||||
inline __host__ __device__ bool is_aligned(T* x) {
|
||||
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ AlignedVector<T, N> unsafe_load_vector(
|
||||
const T* ptr,
|
||||
uint32_t offset) {
|
||||
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||
return from[offset];
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ AlignedVector<T, N> load_vector(
|
||||
const T* ptr,
|
||||
@@ -101,6 +109,13 @@ inline __device__ AlignedVector<T, N> load_vector(
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ void
|
||||
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||
to[offset] = vec;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ void
|
||||
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||
|
||||
Reference in New Issue
Block a user