Compare commits

..

3 Commits

Author SHA1 Message Date
Angelos Katharopoulos
1df9887998 Ensure no oob read in gemv_masked (#2508) 2025-08-17 08:42:33 -07:00
Angelos Katharopoulos
73f22d6226 Ensure small sort doesn't use indices if not argsort (#2506) 2025-08-17 08:42:20 -07:00
Cheng
c422050ca7 Update cuDNN Frontend to v1.14 (#2505) 2025-08-17 19:13:01 +09:00
4 changed files with 56 additions and 53 deletions

View File

@@ -149,7 +149,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.12.1
GIT_TAG v1.14.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)

View File

@@ -7,9 +7,6 @@
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>

View File

@@ -262,10 +262,10 @@ struct GEMVKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
@@ -295,6 +295,7 @@ struct GEMVKernel {
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {
@@ -544,10 +545,10 @@ struct GEMVTKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
@@ -573,6 +574,7 @@ struct GEMVTKernel {
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {

View File

@@ -45,11 +45,13 @@ struct ThreadSort {
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
if (ARG_SORT) {
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
@@ -111,7 +113,9 @@ struct BlockMergeSort {
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
if (ARG_SORT) {
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
}
b_idx += short(pred);
a_idx += short(!pred);