mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
1ba18ff7d9
...
1df9887998
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1df9887998 | ||
|
|
73f22d6226 | ||
|
|
c422050ca7 |
@@ -149,7 +149,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
cudnn
|
cudnn
|
||||||
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||||
GIT_TAG v1.12.1
|
GIT_TAG v1.14.0
|
||||||
GIT_SHALLOW TRUE
|
GIT_SHALLOW TRUE
|
||||||
EXCLUDE_FROM_ALL)
|
EXCLUDE_FROM_ALL)
|
||||||
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||||
|
|||||||
@@ -7,9 +7,6 @@
|
|||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
// cudnn_frontend.h redefines this macro.
|
|
||||||
#undef CHECK_CUDA_ERROR
|
|
||||||
|
|
||||||
#include <cudnn_frontend.h>
|
#include <cudnn_frontend.h>
|
||||||
#include <cudnn_frontend_find_plan.h>
|
#include <cudnn_frontend_find_plan.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|||||||
@@ -262,36 +262,37 @@ struct GEMVKernel {
|
|||||||
vec_mask_offset += vec_mask_step;
|
vec_mask_offset += vec_mask_step;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (leftover > 0 &&
|
if (leftover > 0) {
|
||||||
(!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(bool(mat_mask[mat_mask_offset]) &&
|
(bool(mat_mask[mat_mask_offset]) &&
|
||||||
bool(vec_mask[vec_mask_offset])))) {
|
bool(vec_mask[vec_mask_offset]))) {
|
||||||
T block_scale{1};
|
T block_scale{1};
|
||||||
if (has_mul_operand_mask) {
|
if (has_mul_operand_mask) {
|
||||||
block_scale =
|
block_scale =
|
||||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||||
}
|
|
||||||
|
|
||||||
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
|
|
||||||
|
|
||||||
// Apply scale
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
|
||||||
v_coeff[tn] *= block_scale;
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Per thread work loop
|
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
|
||||||
for (int tm = 0; tm < TM; tm++) {
|
|
||||||
// Load for the row
|
|
||||||
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
|
||||||
|
|
||||||
// Accumulate results
|
// Apply scale
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
v_coeff[tn] *= block_scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per thread work loop
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tm = 0; tm < TM; tm++) {
|
||||||
result[tm] += inter[tn] * v_coeff[tn];
|
// Load for the row
|
||||||
|
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
||||||
|
|
||||||
|
// Accumulate results
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
result[tm] += inter[tn] * v_coeff[tn];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -544,31 +545,32 @@ struct GEMVTKernel {
|
|||||||
vec_mask_offset += vec_mask_step;
|
vec_mask_offset += vec_mask_step;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (leftover > 0 &&
|
if (leftover > 0) {
|
||||||
(!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(bool(mat_mask[mat_mask_offset]) &&
|
(bool(mat_mask[mat_mask_offset]) &&
|
||||||
bool(vec_mask[vec_mask_offset])))) {
|
bool(vec_mask[vec_mask_offset]))) {
|
||||||
T block_scale{1};
|
T block_scale{1};
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
block_scale =
|
|
||||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
|
||||||
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
|
||||||
|
|
||||||
if (has_mul_operand_mask) {
|
if (has_mul_operand_mask) {
|
||||||
v_coeff[tm] *= block_scale;
|
block_scale =
|
||||||
|
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||||
}
|
}
|
||||||
|
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
||||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
|
||||||
}
|
|
||||||
|
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
if (has_mul_operand_mask) {
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
v_coeff[tm] *= block_scale;
|
||||||
result[tn] += v_coeff[tm] * inter[tn];
|
}
|
||||||
|
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||||
|
}
|
||||||
|
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
result[tn] += v_coeff[tm] * inter[tn];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,9 @@ struct ThreadSort {
|
|||||||
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||||
if (op(vals[j + 1], vals[j])) {
|
if (op(vals[j + 1], vals[j])) {
|
||||||
thread_swap(vals[j + 1], vals[j]);
|
thread_swap(vals[j + 1], vals[j]);
|
||||||
thread_swap(idxs[j + 1], idxs[j]);
|
if (ARG_SORT) {
|
||||||
|
thread_swap(idxs[j + 1], idxs[j]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -111,7 +113,9 @@ struct BlockMergeSort {
|
|||||||
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||||
|
|
||||||
vals[i] = pred ? b : a;
|
vals[i] = pred ? b : a;
|
||||||
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
if (ARG_SORT) {
|
||||||
|
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
||||||
|
}
|
||||||
|
|
||||||
b_idx += short(pred);
|
b_idx += short(pred);
|
||||||
a_idx += short(!pred);
|
a_idx += short(!pred);
|
||||||
|
|||||||
Reference in New Issue
Block a user