Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-04 01:36:42 +08:00)

faster rms norm (#2433)
Commit: ef631d63af
Parent: 970dbe8e25
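In short: the CUDA elementwise (unary, binary, ternary, copy), gemv, layer_norm, and rms_norm kernels move from cub::LoadDirectBlocked / cub::StoreDirectBlocked and raw `.val[i]` accesses to the shared load_vector / store_vector helpers. AlignedVector gains operator[], the helpers gain an is_aligned check with an element-wise fallback plus bounds-checked and strided overloads, N_READS is derived from the element size (16 bytes per thread) instead of a hard-coded 4, can_use_gemv drops a redundant divisibility check, and a test for unaligned (sliced) inputs is added to TestOps.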
@@ -28,7 +28,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = Op{}(a[0], b[0]);
+      out_vec[i] = Op{}(a[0], b[0]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -49,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
+      out_vec[i] = Op{}(a[0], b_vec[i]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -70,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
+      out_vec[i] = Op{}(a_vec[i], b[0]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -92,7 +92,7 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
+      out_vec[i] = Op{}(a_vec[i], b_vec[i]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -248,8 +248,7 @@ void binary_op_gpu_inplace(
   } else {
     dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
       if (bopt == BinaryOpType::ScalarVector) {
         kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
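Note (not part of the diff): `16 / sizeof(InType)` pins each thread to a 16-byte access, so the element count per thread now depends on the dtype rather than being fixed at 4. A minimal standalone check of that arithmetic, assuming standard CUDA type sizes:

    // Hypothetical compile-time check of the 16-byte-per-thread rule.
    #include <cuda_fp16.h>
    #include <cstdint>

    static_assert(16 / sizeof(double) == 2, "float64: 2 elements per thread");
    static_assert(16 / sizeof(float) == 4, "float32: 4 elements per thread");
    static_assert(16 / sizeof(__half) == 8, "float16: 8 elements per thread");
    static_assert(16 / sizeof(int8_t) == 16, "int8: 16 elements per thread");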
@@ -33,8 +33,8 @@ binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
       auto out = Op{}(a[0], b[0]);
-      out_a_vec.val[i] = out[0];
-      out_b_vec.val[i] = out[1];
+      out_a_vec[i] = out[0];
+      out_b_vec[i] = out[1];
     }

     store_vector<N_READS>(out_a, index, out_a_vec);
@@ -60,9 +60,9 @@ binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
     AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      auto out = Op{}(a[0], b_vec.val[i]);
-      out_a_vec.val[i] = out[0];
-      out_b_vec.val[i] = out[1];
+      auto out = Op{}(a[0], b_vec[i]);
+      out_a_vec[i] = out[0];
+      out_b_vec[i] = out[1];
     }

     store_vector<N_READS>(out_a, index, out_a_vec);
@@ -88,9 +88,9 @@ binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
     AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      auto out = Op{}(a_vec.val[i], b[0]);
-      out_a_vec.val[i] = out[0];
-      out_b_vec.val[i] = out[1];
+      auto out = Op{}(a_vec[i], b[0]);
+      out_a_vec[i] = out[0];
+      out_b_vec[i] = out[1];
     }

     store_vector<N_READS>(out_a, index, out_a_vec);
@@ -117,9 +117,9 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
     AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      auto out = Op{}(a_vec.val[i], b_vec.val[i]);
-      out_a_vec.val[i] = out[0];
-      out_b_vec.val[i] = out[1];
+      auto out = Op{}(a_vec[i], b_vec[i]);
+      out_a_vec[i] = out[0];
+      out_b_vec[i] = out[1];
     }

     store_vector<N_READS>(out_a, index, out_a_vec);
@@ -270,8 +270,7 @@ void binary_two_op_gpu_inplace(
   } else {
     dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
       if (bopt == BinaryOpType::ScalarVector) {
         kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -22,7 +22,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = cast_to<Out>(in[0]);
+      out_vec[i] = cast_to<Out>(in[0]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -43,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
+      out_vec[i] = cast_to<Out>(in_vec[i]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -65,8 +65,7 @@ void copy_contiguous(
     using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
     using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
     using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-    // TODO: Choose optimized value based on type size.
-    constexpr int N_READS = 4;
+    constexpr int N_READS = 16 / sizeof(InType);
     auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
     if (ctype == CopyType::Vector) {
       kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
@@ -32,21 +32,103 @@ using Strides = cuda::std::array<int64_t, MAX_NDIM>;
 template <typename T, int N>
 struct alignas(sizeof(T) * N) AlignedVector {
   T val[N];
+
+  __device__ T& operator[](int i) {
+    return val[i];
+  }
+
+  __device__ T operator[](int i) const {
+    return val[i];
+  }
 };

+template <int N, typename T>
+inline __device__ bool is_aligned(T* x) {
+  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
+}
+
 template <int N, typename T>
 inline __device__ AlignedVector<T, N> load_vector(
     const T* ptr,
     uint32_t offset) {
-  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
-  return from[offset];
+  if (is_aligned<N>(ptr)) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] = ptr[offset * N + i];
+    }
+    return v;
+  }
+}
+
+template <int N, typename T, typename SizeT>
+inline __device__ AlignedVector<T, N>
+load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
+    }
+    return v;
+  }
+}
+
+template <int N, typename T, typename SizeT>
+inline __device__ AlignedVector<T, N> load_vector(
+    const T* ptr,
+    uint32_t offset,
+    SizeT size,
+    int64_t stride,
+    T fallback) {
+  if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] =
+          (N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
+    }
+    return v;
+  }
 }

 template <int N, typename T>
 inline __device__ void
 store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
-  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
-  to[offset] = vec;
+  if (is_aligned<N>(ptr)) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      ptr[offset * N + i] = vec[i];
+    }
+  }
+}
+
+template <int N, typename T, typename SizeT>
+inline __device__ void store_vector(
+    T* ptr,
+    uint32_t offset,
+    const AlignedVector<T, N>& vec,
+    SizeT size) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
+      ptr[offset * N + i] = vec[i];
+    }
+  }
 }

 // Helper for accessing strided data.
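Note (not part of the diff): a minimal sketch of how the new helpers compose in a kernel, assuming the AlignedVector / is_aligned / load_vector / store_vector definitions above are in scope. The kernel name `scale` and its parameters are illustrative only, not part of MLX:

    template <typename T, typename IdxT, int N_READS>
    __global__ void scale(const T* in, T* out, T factor, IdxT size) {
      IdxT index = blockIdx.x * blockDim.x + threadIdx.x;
      if (index * N_READS >= size) {
        return;
      }
      // Bounds-checked load: unaligned pointers fall back to element-wise reads,
      // and lanes past `size` are filled with the fallback value T(0).
      auto v = load_vector<N_READS>(in, index, size, T(0));
    #pragma unroll
      for (int i = 0; i < N_READS; ++i) {
        v[i] = v[i] * factor; // uses the new AlignedVector::operator[]
      }
      // Bounds-checked store: only lanes inside `size` are written.
      store_vector<N_READS>(out, index, v, size);
    }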
@@ -31,8 +31,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
     auto local_vec = load_vector<n_per_thread>(vec + col, 0);
 #pragma unroll
     for (int j = 0; j < n_per_thread; ++j) {
-      sum += static_cast<float>(local_mat.val[j]) *
-          static_cast<float>(local_vec.val[j]);
+      sum +=
+          static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
     }
   }

@@ -73,8 +73,7 @@ __global__ void gemv_batched(
 }

 bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
-  bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
-  return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
+  return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
 }

 template <typename F>
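Note (not part of the diff): the dropped `K % 64` / `K % 128` terms were redundant, since any K divisible by 64 or 128 is also divisible by 32; the single `K % 32 == 0` test accepts exactly the same values, so this is a simplification rather than a behavior change. A trivial compile-time spot check:

    constexpr bool same_result(int K) {
      bool old_check = (K % 32 == 0) || (K % 64 == 0) || (K % 128 == 0);
      bool new_check = (K % 32 == 0);
      return old_check == new_check;
    }
    static_assert(
        same_result(32) && same_result(64) && same_result(96) &&
            same_result(128) && same_result(100),
        "old and new checks agree on these sample values");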
@@ -10,8 +10,6 @@
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 #include <nvtx3/nvtx3.hpp>
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_reduce.cuh>

 namespace mlx::core {

@@ -74,9 +72,11 @@ __global__ void layer_norm(
   float sum = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS] = {};
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      sum += static_cast<float>(xn[i]);
+    }
   }
   sum = BlockReduceT{block, temp}.Sum(sum);

@@ -87,11 +87,18 @@ __global__ void layer_norm(
   float normalizer = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
-    for (int i = 0; i < N_READS; ++i) {
-      float t = static_cast<float>(xn[i]) - mean;
-      normalizer += t * t;
+    if ((index + 1) * N_READS <= axis_size) {
+      auto xn = load_vector<N_READS>(x, index);
+#pragma unroll
+      for (int i = 0; i < N_READS; ++i) {
+        float t = static_cast<float>(xn[i]) - mean;
+        normalizer += t * t;
+      }
+    } else {
+      for (int i = index * N_READS; i < axis_size; ++i) {
+        float t = static_cast<float>(x[i]) - mean;
+        normalizer += t * t;
+      }
     }
   }
   normalizer = BlockReduceT{block, temp}.Sum(normalizer);
@@ -100,17 +107,15 @@ __global__ void layer_norm(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T bn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+    auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
       float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
       xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
     }
-    cub::StoreDirectBlocked(index, out, xn, axis_size);
+    store_vector<N_READS>(out, index, xn, axis_size);
   }
 }

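Note (not part of the diff): a plain host-side reference of the per-row computation the layer_norm kernel above vectorizes, assuming the usual definition normalizer = 1 / sqrt(variance + eps); the kernel's fast path handles full N_READS chunks with unchecked vector loads and the tail falls back to scalar reads.

    #include <cmath>
    #include <cstdint>

    // out[i] = w[i * w_stride] * (x[i] - mean) / sqrt(variance + eps) + b[i * b_stride]
    void layer_norm_reference(
        const float* x,
        const float* w,
        const float* b,
        float* out,
        float eps,
        uint32_t axis_size,
        int64_t w_stride,
        int64_t b_stride) {
      float mean = 0.0f;
      for (uint32_t i = 0; i < axis_size; ++i) {
        mean += x[i];
      }
      mean /= axis_size;

      float variance = 0.0f;
      for (uint32_t i = 0; i < axis_size; ++i) {
        float t = x[i] - mean;
        variance += t * t;
      }
      variance /= axis_size;

      float normalizer = 1.0f / std::sqrt(variance + eps);
      for (uint32_t i = 0; i < axis_size; ++i) {
        out[i] = w[i * w_stride] * ((x[i] - mean) * normalizer) + b[i * b_stride];
      }
    }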
@@ -143,9 +148,11 @@ __global__ void layer_norm_vjp(
   float sum = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS] = {};
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      sum += static_cast<float>(xn[i]);
+    }
   }
   sum = BlockReduceF{block, temp.f}.Sum(sum);

@@ -155,19 +162,28 @@ __global__ void layer_norm_vjp(
   // Normalizer.
   float3 factors = {};
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T xn[N_READS];
-    T wn[N_READS] = {};
-    T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
-    for (int i = 0; i < N_READS; i++) {
-      float t = static_cast<float>(xn[i]) - mean;
-      float wi = wn[i];
-      float gi = gn[i];
-      float wg = wi * gi;
-      factors = plus_f3(factors, {wg, wg * t, t * t});
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+
+    if ((index + 1) * N_READS <= axis_size) {
+      auto xn = load_vector<N_READS>(x, index);
+#pragma unroll
+      for (int i = 0; i < N_READS; ++i) {
+        float t = static_cast<float>(xn[i]) - mean;
+        float wi = wn[i];
+        float gi = gn[i];
+        float wg = wi * gi;
+        factors = plus_f3(factors, {wg, wg * t, t * t});
+      }
+    } else {
+      for (int i = index * N_READS; i < axis_size; ++i) {
+        float t = static_cast<float>(x[i]) - mean;
+        float wi = wn[i];
+        float gi = gn[i];
+        float wg = wi * gi;
+        factors = plus_f3(factors, {wg, wg * t, t * t});
+      }
     }
   }
   factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
@@ -179,12 +195,10 @@ __global__ void layer_norm_vjp(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T gn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+
     for (int i = 0; i < N_READS; i++) {
       float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
       float wi = wn[i];
@@ -194,9 +208,9 @@ __global__ void layer_norm_vjp(
         wn[i] = gi * xi;
       }
     }
-    cub::StoreDirectBlocked(index, gx, xn, axis_size);
+    store_vector<N_READS>(gx, index, xn, axis_size);
     if constexpr (HAS_W) {
-      cub::StoreDirectBlocked(index, gw, wn, axis_size);
+      store_vector<N_READS>(gw, index, wn, axis_size);
     }
   }
 }
@@ -257,9 +271,9 @@ void LayerNorm::eval_gpu(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
-    constexpr uint32_t N_READS = 4;
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
     dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
       encoder.add_kernel_node(
           kernel,
@@ -364,10 +378,10 @@ void LayerNormVJP::eval_gpu(
   encoder.set_output_array(gw_temp);
   dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
     dispatch_bool(has_w, [&](auto has_w_constant) {
-      constexpr int N_READS = 4;
+      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      constexpr int N_READS = 16 / sizeof(DataType);
       dispatch_block_dim(
           cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
             auto kernel = cu::layer_norm_vjp<
                 DataType,
                 has_w_constant.value,
@@ -5,8 +5,6 @@
 #include "mlx/backend/gpu/copy.h"

 #include <nvtx3/nvtx3.hpp>
-#include <thrust/device_ptr.h>
-#include <thrust/fill.h>

 #include <cassert>

@@ -10,8 +10,6 @@
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 #include <nvtx3/nvtx3.hpp>
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_reduce.cuh>

 namespace mlx::core {

@@ -57,7 +55,7 @@ __global__ void rms_norm(
     const T* w,
     T* out,
     float eps,
-    int32_t axis_size,
+    uint32_t axis_size,
     int64_t w_stride) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
@@ -72,8 +70,8 @@ __global__ void rms_norm(
   float normalizer = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
       float t = static_cast<float>(xn[i]);
       normalizer += t * t;
@@ -85,15 +83,14 @@ __global__ void rms_norm(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      float norm = static_cast<float>(xn[i]) * normalizer;
-      xn[i] = wn[i] * static_cast<T>(norm);
+      float y = static_cast<float>(xn[i]) * normalizer;
+      xn[i] = wn[i] * static_cast<T>(y);
     }
-    cub::StoreDirectBlocked(index, out, xn, axis_size);
+    store_vector<N_READS>(out, index, xn, axis_size);
   }
 }

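Note (not part of the diff): the corresponding host-side reference for the rms_norm kernel above, assuming normalizer = 1 / sqrt(mean(x^2) + eps) as computed from the squared sum:

    #include <cmath>
    #include <cstdint>

    // out[i] = w[i * w_stride] * x[i] / sqrt(mean(x * x) + eps)
    void rms_norm_reference(
        const float* x,
        const float* w,
        float* out,
        float eps,
        uint32_t axis_size,
        int64_t w_stride) {
      float sum_sq = 0.0f;
      for (uint32_t i = 0; i < axis_size; ++i) {
        sum_sq += x[i] * x[i];
      }
      float normalizer = 1.0f / std::sqrt(sum_sq / axis_size + eps);
      for (uint32_t i = 0; i < axis_size; ++i) {
        out[i] = w[i * w_stride] * (x[i] * normalizer);
      }
    }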
@@ -125,13 +122,10 @@ __global__ void rms_norm_vjp(
   // Normalizer.
   float2 factors = {};
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T xn[N_READS];
-    T wn[N_READS] = {};
-    T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
     for (int i = 0; i < N_READS; i++) {
       float t = static_cast<float>(xn[i]);
       float wi = wn[i];
@@ -148,12 +142,9 @@ __global__ void rms_norm_vjp(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T gn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
     for (int i = 0; i < N_READS; i++) {
       float xi = xn[i];
       float wi = wn[i];
@@ -163,9 +154,9 @@ __global__ void rms_norm_vjp(
        wn[i] = static_cast<T>(gi * xi * normalizer);
      }
    }
-    cub::StoreDirectBlocked(index, gx, xn, axis_size);
+    store_vector<N_READS>(gx, index, xn, axis_size);
     if constexpr (HAS_W) {
-      cub::StoreDirectBlocked(index, gw, wn, axis_size);
+      store_vector<N_READS>(gw, index, wn, axis_size);
     }
   }
 }
@@ -223,9 +214,9 @@ void RMSNorm::eval_gpu(
   encoder.set_input_array(w);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
-    constexpr uint32_t N_READS = 4;
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
     dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
       encoder.add_kernel_node(
           kernel,
@@ -312,11 +303,10 @@ void RMSNormVJP::eval_gpu(
   encoder.set_output_array(gw_temp);
   dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
     dispatch_bool(has_w, [&](auto has_w_constant) {
-      constexpr int N_READS = 4;
+      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      constexpr int N_READS = 16 / sizeof(DataType);
       dispatch_block_dim(
           cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-            constexpr int N_READS = 4;
             auto kernel = cu::rms_norm_vjp<
                 DataType,
                 has_w_constant.value,
@@ -32,7 +32,7 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
     AlignedVector<T, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
+      out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -166,8 +166,7 @@ void ternary_op_gpu_inplace(
   } else {
     dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(DType);
       auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
       auto [num_blocks, block_dims] = get_launch_args(
           kernel,
@@ -30,7 +30,7 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = Op{}(in_vec.val[i]);
+      out_vec[i] = Op{}(in_vec[i]);
     }

     store_vector<N_READS>(out, index, out_vec);
@@ -3049,6 +3049,25 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.power(mx.array(0j), float("nan"))
         self.assertTrue(mx.isnan(out))

+    def test_irregular_alignments(self):
+        # Unaligned unary op
+        a = mx.ones((64, 1))
+        b = -a[1:]
+        self.assertTrue(mx.all(b == -1.0))
+
+        # Unaligned binary op
+        a = mx.ones((64, 1))
+        b = a[1:]
+        c = b + b
+        self.assertTrue(mx.all(c == 2.0))
+
+        # Unaligned ternary op
+        a = mx.ones((64, 1))
+        b = mx.zeros((63, 1))
+        c = mx.ones((63, 1)).astype(mx.bool_)
+        d = mx.where(c, a[1:], b)
+        self.assertTrue(mx.all(d == 1.0))
+

 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):