mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster general unary op (#2472)
* faster general unary op * faster general ops + reorg * fix + comment * binary two * copy general
This commit is contained in:
@@ -146,6 +146,23 @@ inline __device__ void store_vector(
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename T, typename SizeT>
|
||||
inline __device__ void store_vector(
|
||||
T* ptr,
|
||||
uint32_t offset,
|
||||
const AlignedVector<T, N>& vec,
|
||||
SizeT size,
|
||||
int64_t stride) {
|
||||
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
|
||||
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||
to[offset] = vec;
|
||||
} else {
|
||||
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
|
||||
ptr[stride * (offset * N + i)] = vec[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user