mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Fix cuda arg reduce (#2291)
This commit is contained in:
parent
a6d780154f
commit
a14aaa7c9d
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||||
@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
|
|||||||
|
|
||||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
auto tid = r * BLOCK_DIM + block.thread_index().z;
|
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||||
best = op.reduce_many(best, vals, tid * N_READS);
|
best = op.reduce_many(best, vals, tid * N_READS);
|
||||||
@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
constexpr uint32_t N_READS = 4;
|
constexpr uint32_t N_READS = 4;
|
||||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
dim3 block_dims{1, 1, BLOCK_DIM};
|
dim3 block_dims{BLOCK_DIM, 1, 1};
|
||||||
auto kernel = &cu::arg_reduce_general<
|
auto kernel = &cu::arg_reduce_general<
|
||||||
InType,
|
InType,
|
||||||
cu::ArgMax<InType>,
|
cu::ArgMax<InType>,
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cublasLt.h>
|
#include <cublasLt.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
@ -45,7 +46,7 @@ class MatMul {
|
|||||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||||
|
|
||||||
auto scale_type = dtype_to_cuda_type(dtype);
|
auto scale_type = dtype_to_cuda_type(dtype);
|
||||||
if (dtype == bfloat16) {
|
if (dtype == bfloat16 || dtype == float16) {
|
||||||
scale_type = CUDA_R_32F;
|
scale_type = CUDA_R_32F;
|
||||||
}
|
}
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||||
@ -192,11 +193,12 @@ class MatMul {
|
|||||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||||
switch (dtype) {
|
switch (dtype) {
|
||||||
case float16:
|
case float16:
|
||||||
return CUBLAS_COMPUTE_16F;
|
return CUBLAS_COMPUTE_32F;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return CUBLAS_COMPUTE_32F;
|
return CUBLAS_COMPUTE_32F;
|
||||||
case float32:
|
case float32:
|
||||||
return CUBLAS_COMPUTE_32F;
|
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||||
|
: CUBLAS_COMPUTE_32F;
|
||||||
case float64:
|
case float64:
|
||||||
case complex64:
|
case complex64:
|
||||||
return CUBLAS_COMPUTE_64F;
|
return CUBLAS_COMPUTE_64F;
|
||||||
|
@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
|
|||||||
return metal_fast_synch;
|
return metal_fast_synch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool enable_tf32() {
|
||||||
|
static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
|
||||||
|
return enable_tf32_;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace env
|
} // namespace env
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user