Faster general unary op (#2472)

* faster general unary op

* faster general ops + reorg

* fix + comment

* binary two

* copy general
This commit is contained in:
Awni Hannun
2025-08-15 15:04:12 -07:00
committed by GitHub
parent dfb5022eab
commit 6441c21a94
62 changed files with 1215 additions and 203 deletions

View File

@@ -146,6 +146,23 @@ inline __device__ void store_vector(
}
}
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size,
int64_t stride) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[stride * (offset * N + i)] = vec[i];
}
}
}
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////