Compare commits

..

3 Commits

Author SHA1 Message Date
Cheng
2b95d0c270 [CUDA] Use cuDNN attention when T_q != T_kv (#2843)
2025-11-27 09:58:43 +09:00
Chaoran Yu
b054838780 Added clarification to apply_fn parameter of apply_to_modules (#2831)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-26 15:40:56 -08:00
Awni Hannun
dd79d3c465 [CUDA] Faster rms norm for small dimension (#2838) 2025-11-26 15:10:41 -08:00
4 changed files with 285 additions and 54 deletions

View File

@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
 }
 
 // Similar to cub::BlockReduce, but result is broadcasted to every thread.
-template <typename T, int BLOCK_DIM>
+template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
 struct BlockBroadcastReduce {
-  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
-  static_assert(BLOCK_DIM % WARP_SIZE == 0);
-  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
+  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
 
   cg::thread_block& block;
   TempStorage& temp;
 
   template <typename Op>
   __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
-    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+    auto warp = cg::tiled_partition<GROUP_DIM>(block);
     T x = cg::reduce(warp, input, op);
-    if (warp.thread_rank() == 0) {
-      temp[warp.meta_group_rank()] = x;
+    if constexpr (BLOCK_DIM > GROUP_DIM) {
+      if (warp.thread_rank() == 0) {
+        temp[warp.meta_group_rank()] = x;
+      }
+      block.sync();
+      x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
+                                                      : init_value;
+      return cg::reduce(warp, x, op);
+    } else {
+      return x;
     }
-    block.sync();
-    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
-                                                    : init_value;
-    return cg::reduce(warp, x, op);
   }
 
   __device__ T Sum(const T& input) {
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
   }
 };
 
+template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
+__global__ void rms_norm_small(
+    const T* x,
+    const T* w,
+    T* out,
+    float eps,
+    uint32_t axis_size,
+    uint32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+  x += row * axis_size;
+  out += row * axis_size;
+
+  // Normalizer.
+  float normalizer = 0;
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float t = static_cast<float>(xn[i]);
+    normalizer += t * t;
+  }
+  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
+  normalizer = rsqrt(normalizer / axis_size + eps);
+
+  // Outputs.
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float y = static_cast<float>(xn[i]) * normalizer;
+    xn[i] = wn[i] * static_cast<T>(y);
+  }
+  store_vector<N_READS>(out, index, xn, axis_size);
+}
+
 template <typename T, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm(
     const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
   }
 }
 
+template <
+    typename T,
+    bool HAS_W,
+    int BLOCK_DIM,
+    int REDUCE_DIM,
+    int N_READS = 4>
+__global__ void rms_norm_vjp_small(
+    const T* x,
+    const T* w,
+    const T* g,
+    T* gx,
+    T* gw,
+    float eps,
+    int32_t axis_size,
+    int32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceF2::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+  x += row * axis_size;
+  g += row * axis_size;
+  gx += row * axis_size;
+  gw += row * axis_size;
+
+  // Normalizer.
+  float2 factors = {};
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+  for (int i = 0; i < N_READS; i++) {
+    float t = static_cast<float>(xn[i]);
+    float wi = wn[i];
+    float gi = gn[i];
+    float wg = wi * gi;
+    factors = plus_f2(factors, {wg * t, t * t});
+  }
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
+  float meangwx = factors.x / axis_size;
+  float normalizer = rsqrt(factors.y / axis_size + eps);
+  float normalizer3 = normalizer * normalizer * normalizer;
+
+  // Outputs.
+  for (int i = 0; i < N_READS; i++) {
+    float xi = xn[i];
+    float wi = wn[i];
+    float gi = gn[i];
+    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
+    if constexpr (HAS_W) {
+      wn[i] = static_cast<T>(gi * xi * normalizer);
+    }
+  }
+  store_vector<N_READS>(gx, index, xn, axis_size);
+  if constexpr (HAS_W) {
+    store_vector<N_READS>(gw, index, wn, axis_size);
+  }
+}
+
 template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm_vjp(
     const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
 
-  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
   using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
-  __shared__ union {
-    typename BlockReduceF::TempStorage f;
-    typename BlockReduceF2::TempStorage f2;
-  } temp;
+  __shared__ typename BlockReduceF2::TempStorage temp;
 
   x += grid.block_rank() * axis_size;
   g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
       factors = plus_f2(factors, {wg * t, t * t});
     }
   }
-  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
   float meangwx = factors.x / axis_size;
   float normalizer = rsqrt(factors.y / axis_size + eps);
   float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
   return s.device == Device::cpu;
 }
 
+template <int n_per_thread, typename F>
+void dispatch_group_dim(int axis_size, F&& f) {
+  if (axis_size <= n_per_thread * 8) {
+    f(std::integral_constant<int, 8>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 16>());
+  } else if (axis_size <= n_per_thread * 16) {
+    f(std::integral_constant<int, 16>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 8>());
+  } else if (axis_size <= n_per_thread * 32) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 4>());
+  } else if (axis_size <= n_per_thread * 32 * 2) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 2>(),
+      std::integral_constant<int, 2>());
+  } else if (axis_size <= n_per_thread * 32 * 4) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 4>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 8) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 8>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 16) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 16>(),
+      std::integral_constant<int, 1>());
+  } else {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 32>(),
+      std::integral_constant<int, 1>());
+  }
+}
+
 // TODO: There are duplicate code with backend/metal/normalization.cpp
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     constexpr int N_READS = 16 / sizeof(DataType);
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+    if (axis_size <= N_READS * 1024) {
+      dispatch_group_dim<N_READS>(
+          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
+            constexpr int block_dim = n_groups() * group_dim();
+            auto kernel =
+                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
+            auto n_blocks =
+                (n_rows + groups_per_block() - 1) / groups_per_block();
+            encoder.add_kernel_node(
+                kernel,
+                n_blocks,
+                {block_dim, groups_per_block()},
+                0,
+                gpu_ptr<DataType>(x),
+                gpu_ptr<DataType>(w),
+                gpu_ptr<DataType>(out),
+                eps_,
+                axis_size,
+                n_rows,
+                w_stride);
+          });
+    } else {
+      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
       encoder.add_kernel_node(
           kernel,
          n_rows,
-          block_dim(),
+          1024,
          0,
          gpu_ptr<DataType>(x),
          gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
          eps_,
          axis_size,
          w_stride);
-    });
+    }
  });
 }
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
     dispatch_bool(has_w, [&](auto has_w_constant) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       constexpr int N_READS = 16 / sizeof(DataType);
-      dispatch_block_dim(
-          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            auto kernel = cu::rms_norm_vjp<
-                DataType,
-                has_w_constant.value,
-                block_dim(),
-                N_READS>;
-            encoder.add_kernel_node(
-                kernel,
-                n_rows,
-                block_dim(),
-                0,
-                gpu_ptr<DataType>(x),
-                gpu_ptr<DataType>(w),
-                gpu_ptr<DataType>(g),
-                gpu_ptr<DataType>(gx),
-                gpu_ptr<DataType>(gw_temp),
-                eps_,
-                axis_size,
-                w_stride);
-          });
+      if (axis_size <= N_READS * 1024) {
+        dispatch_group_dim<N_READS>(
+            axis_size,
+            [&](auto group_dim, auto n_groups, auto groups_per_block) {
+              constexpr int block_dim = group_dim() * n_groups();
+              auto kernel = cu::rms_norm_vjp_small<
+                  DataType,
+                  has_w_constant.value,
+                  block_dim,
+                  group_dim(),
+                  N_READS>;
+              auto n_blocks =
+                  (n_rows + groups_per_block() - 1) / groups_per_block();
+              encoder.add_kernel_node(
+                  kernel,
+                  n_blocks,
+                  {block_dim, groups_per_block()},
+                  0,
+                  gpu_ptr<DataType>(x),
+                  gpu_ptr<DataType>(w),
+                  gpu_ptr<DataType>(g),
+                  gpu_ptr<DataType>(gx),
+                  gpu_ptr<DataType>(gw_temp),
+                  eps_,
+                  axis_size,
+                  n_rows,
+                  w_stride);
+            });
+      } else {
+        auto kernel =
+            cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
+        encoder.add_kernel_node(
+            kernel,
+            n_rows,
+            1024,
+            0,
+            gpu_ptr<DataType>(x),
+            gpu_ptr<DataType>(w),
+            gpu_ptr<DataType>(g),
+            gpu_ptr<DataType>(gx),
+            gpu_ptr<DataType>(gw_temp),
+            eps_,
+            axis_size,
+            w_stride);
+      }
     });
   });
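Note (a reading of the diff, not part of the commit): `dispatch_group_dim` picks how many threads reduce one row and how many rows share a thread block, so short rows no longer occupy a full 1024-thread block. A rough Python re-statement of the mapping, assuming `n_per_thread` corresponds to `N_READS` (16 / sizeof(T)):

def dispatch_group_dim(axis_size, n_per_thread):
    # Returns (group_dim, n_groups, groups_per_block); the kernel launches a
    # (group_dim * n_groups) x groups_per_block block and each block handles
    # groups_per_block rows.
    table = [
        (8, 1, 16), (16, 1, 8), (32, 1, 4), (32, 2, 2),
        (32, 4, 1), (32, 8, 1), (32, 16, 1), (32, 32, 1),
    ]
    for limit, cfg in zip((8, 16, 32, 64, 128, 256, 512), table):
        if axis_size <= n_per_thread * limit:
            return cfg
    return table[-1]

# Example: a float16 row of length 64 (N_READS = 8) gets an 8-thread group
# and 16 rows per block, instead of one row per 1024-thread block.
print(dispatch_group_dim(64, 8))  # -> (8, 1, 16)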

View File

@@ -144,13 +144,13 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
 
 auto& sdpa_cache() {
   static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
-      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
+      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
   return cache;
 }
 
 auto& sdpa_backward_cache() {
   static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
-      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
+      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
   return cache;
 }
@@ -207,8 +207,14 @@ fe::graph::Graph build_sdpa_graph(
   auto options = fe::graph::SDPA_attributes()
                      .set_name("sdpa_cudnn")
                      .set_attn_scale(scale)
-                     .set_causal_mask(do_causal)
                      .set_generate_stats(output_logsumexp);
+  if (do_causal) {
+    if (q.shape(2) > k.shape(2)) {
+      options.set_causal_mask(do_causal);
+    } else {
+      options.set_causal_mask_bottom_right(do_causal);
+    }
+  }
   if (mask_arr) {
     auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
     set_tensor_attrs(bias_, BIAS, *mask_arr);
@@ -282,7 +288,14 @@ fe::graph::Graph build_sdpa_backward_graph(
   auto options = fe::graph::SDPA_backward_attributes()
                      .set_name("sdpa_backward_cudnn")
-                     .set_attn_scale(scale)
-                     .set_causal_mask(do_causal);
+                     .set_attn_scale(scale);
+  if (do_causal) {
+    if (q.shape(2) > k.shape(2)) {
+      options.set_causal_mask(do_causal);
+    } else {
+      options.set_causal_mask_bottom_right(do_causal);
+    }
+  }
   if (mask_arr) {
     auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
     set_tensor_attrs(bias_, BIAS, *mask_arr);
@@ -340,6 +353,7 @@ bool supports_sdpa_cudnn(
     const array& q,
     const array& k,
     const array& v,
+    bool do_causal,
     Stream s) {
   static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
   if (!enabled) {
@@ -351,8 +365,8 @@ bool supports_sdpa_cudnn(
     return false;
   }
 
-  // Only use cuDNN for prefilling and training.
-  if (q.shape(2) != k.shape(2)) {
+  // Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
+  if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
     return false;
   }
@@ -520,7 +534,7 @@ bool ScaledDotProductAttention::use_fallback(
   return !supports_sdpa_vector(
              q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
-      !supports_sdpa_cudnn(q, k, v, s);
+      !supports_sdpa_cudnn(q, k, v, do_causal, s);
 }
 
 bool ScaledDotProductAttention::supports_bool_mask() {
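For context (an illustrative sketch, not from the diff): the case this change targets is causal attention where the query and key/value lengths differ, e.g. prefilling a chunk against a longer KV cache. Per the diff, when T_q <= T_kv the mask is set with `set_causal_mask_bottom_right`, so the last query token still attends to the whole cache. Hypothetical shapes:

import mlx.core as mx

B, H, D = 1, 8, 64
T_q, T_kv = 128, 512  # query chunk attending to a longer cached context

q = mx.random.normal((B, H, T_q, D))
k = mx.random.normal((B, H, T_kv, D))
v = mx.random.normal((B, H, T_kv, D))

# With a causal mask and T_q != T_kv, this call may now be eligible for the
# cuDNN path (previously it required T_q == T_kv).
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask="causal")
print(out.shape)  # (1, 8, 128, 64)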

View File

@@ -407,7 +407,10 @@ class Module(dict):
            instance).
 
         Args:
-            apply_fn (Callable): The function to apply to the modules.
+            apply_fn (Callable): The function to apply to the modules which
+                takes two parameters. The first parameter is the string path of
+                the module (e.g. ``"model.layers.0.linear"``). The second
+                parameter is the module object.
 
         Returns:
            The module instance after updating submodules.
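An illustrative use of the two-parameter `apply_fn` described above (the model and callback here are made up for the example, not part of the diff):

import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

def report(path, module):
    # path is the dotted location of the submodule (e.g. "layers.0");
    # module is the nn.Module instance at that location.
    print(path, type(module).__name__)

model.apply_to_modules(report)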

View File

@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
             same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
 
         Returns:
-            list(array): A list of the Jacobian-vector products which
-            is the same in number, shape, and type of the inputs to ``fun``.
+            tuple(list(array), list(array)): A tuple with the outputs of
+            ``fun`` in the first position and the Jacobian-vector products
+            in the second position.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+
+                outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
       )pbdoc");
   m.def(
       "vjp",
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
             same in number, shape, and type as the outputs of ``fun``.
 
         Returns:
-            list(array): A list of the vector-Jacobian products which
-            is the same in number, shape, and type of the outputs of ``fun``.
+            tuple(list(array), list(array)): A tuple with the outputs of
+            ``fun`` in the first position and the vector-Jacobian products
+            in the second position.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+
+                outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
       )pbdoc");
   m.def(
       "value_and_grad",