diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu
index 7dbd91e46..c8a5a962a 100644
--- a/mlx/backend/cuda/arg_reduce.cu
+++ b/mlx/backend/cuda/arg_reduce.cu
@@ -1,5 +1,4 @@
 // Copyright © 2025 Apple Inc.
-
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
@@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
 
     for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
       T vals[N_READS];
-      auto tid = r * BLOCK_DIM + block.thread_index().z;
+      auto tid = r * BLOCK_DIM + block.thread_index().x;
       cub::LoadDirectBlocked(
           tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
       best = op.reduce_many(best, vals, tid * N_READS);
@@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     constexpr uint32_t N_READS = 4;
     MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
       dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-      dim3 block_dims{1, 1, BLOCK_DIM};
+      dim3 block_dims{BLOCK_DIM, 1, 1};
       auto kernel = &cu::arg_reduce_general<
           InType,
           cu::ArgMax,
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 9930c75b8..5a5e6182e 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -5,6 +5,7 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
 
 #include <cublasLt.h>
 #include <fmt/format.h>
@@ -45,7 +46,7 @@ class MatMul {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
     auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16) {
+    if (dtype == bfloat16 || dtype == float16) {
       scale_type = CUDA_R_32F;
     }
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
@@ -192,11 +193,12 @@ class MatMul {
   cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     switch (dtype) {
       case float16:
-        return CUBLAS_COMPUTE_16F;
+        return CUBLAS_COMPUTE_32F;
      case bfloat16:
         return CUBLAS_COMPUTE_32F;
       case float32:
-        return CUBLAS_COMPUTE_32F;
+        return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                             : CUBLAS_COMPUTE_32F;
       case float64:
       case complex64:
         return CUBLAS_COMPUTE_64F;
diff --git a/mlx/utils.h b/mlx/utils.h
index f0aa7c2de..f16bf0468 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -149,6 +149,11 @@
 inline bool metal_fast_synch() {
   return metal_fast_synch;
 }
 
+inline bool enable_tf32() {
+  static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
+  return enable_tf32_;
+}
+
 } // namespace env
 } // namespace mlx::core