mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
1ba18ff7d9
...
1df9887998
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1df9887998 | ||
|
|
73f22d6226 | ||
|
|
c422050ca7 |
@@ -149,7 +149,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
FetchContent_Declare(
|
||||
cudnn
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||
GIT_TAG v1.12.1
|
||||
GIT_TAG v1.14.0
|
||||
GIT_SHALLOW TRUE
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||
|
||||
@@ -7,9 +7,6 @@
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
// cudnn_frontend.h redefines this macro.
|
||||
#undef CHECK_CUDA_ERROR
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
@@ -262,10 +262,10 @@ struct GEMVKernel {
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
if (leftover > 0) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
@@ -295,6 +295,7 @@ struct GEMVKernel {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply out scale
|
||||
if (has_mul_output_mask) {
|
||||
@@ -544,10 +545,10 @@ struct GEMVTKernel {
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
if (leftover > 0) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
@@ -573,6 +574,7 @@ struct GEMVTKernel {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply out scale
|
||||
if (has_mul_output_mask) {
|
||||
|
||||
@@ -45,11 +45,13 @@ struct ThreadSort {
|
||||
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||
if (op(vals[j + 1], vals[j])) {
|
||||
thread_swap(vals[j + 1], vals[j]);
|
||||
if (ARG_SORT) {
|
||||
thread_swap(idxs[j + 1], idxs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -111,7 +113,9 @@ struct BlockMergeSort {
|
||||
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||
|
||||
vals[i] = pred ? b : a;
|
||||
if (ARG_SORT) {
|
||||
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
||||
}
|
||||
|
||||
b_idx += short(pred);
|
||||
a_idx += short(!pred);
|
||||
|
||||
Reference in New Issue
Block a user