Compare commits


1 commit

Author: Eric Buehler
Commit: 777bce6e25
Message: Merge 4d68bd3250 into a6d780154f
Date: 2025-06-14 08:14:58 +02:00
3 changed files with 6 additions and 12 deletions

File 1 of 3

@@ -1,4 +1,5 @@
 // Copyright © 2025 Apple Inc.
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
@@ -112,7 +113,7 @@ __global__ void arg_reduce_general(
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     T vals[N_READS];
-    auto tid = r * BLOCK_DIM + block.thread_index().x;
+    auto tid = r * BLOCK_DIM + block.thread_index().z;
     cub::LoadDirectBlocked(
         tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
     best = op.reduce_many(best, vals, tid * N_READS);
@@ -157,7 +158,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   constexpr uint32_t N_READS = 4;
   MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
     dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-    dim3 block_dims{BLOCK_DIM, 1, 1};
+    dim3 block_dims{1, 1, BLOCK_DIM};
     auto kernel = &cu::arg_reduce_general<
         InType,
         cu::ArgMax<InType>,
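
Note that the two hunks above have to change together: the dimension read from block.thread_index() must be the one that block_dims actually spreads the block's threads across, otherwise every thread computes the same tid and the blocked loads overlap. A minimal standalone sketch of that pairing (a hypothetical kernel with illustrative names, not the MLX arg_reduce_general kernel) is below.

// Minimal sketch, using only standard CUDA cooperative groups: with
// dim3 block_dims{1, 1, BLOCK_DIM}, the per-thread index lives in .z;
// reading .x here instead would make every thread see tid == r * BLOCK_DIM.
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

template <typename T, int BLOCK_DIM, int N_READS>
__global__ void blocked_read_sketch(const T* in, T* out, int axis_size) {
  auto block = cg::this_thread_block();
  int rounds = (axis_size + BLOCK_DIM * N_READS - 1) / (BLOCK_DIM * N_READS);
  for (int r = 0; r < rounds; ++r) {
    // New pairing in the diff: launch with {1, 1, BLOCK_DIM}, index with .z
    // (the old pairing was {BLOCK_DIM, 1, 1} with .x).
    int tid = r * BLOCK_DIM + block.thread_index().z;
    for (int i = 0; i < N_READS; ++i) {
      int idx = tid * N_READS + i;  // stand-in for cub::LoadDirectBlocked
      if (idx < axis_size) {
        out[idx] = in[idx];
      }
    }
  }
}

Launched with dim3{1, 1, BLOCK_DIM}, each thread handles N_READS consecutive elements per round, which mirrors how the real kernel feeds cub::LoadDirectBlocked.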

File 2 of 3

@@ -5,7 +5,6 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
-#include "mlx/utils.h"
 #include <cublasLt.h>
 #include <fmt/format.h>
@@ -46,7 +45,7 @@ class MatMul {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
     auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16 || dtype == float16) {
+    if (dtype == bfloat16) {
       scale_type = CUDA_R_32F;
     }
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
@@ -193,12 +192,11 @@ class MatMul {
   cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     switch (dtype) {
       case float16:
-        return CUBLAS_COMPUTE_32F;
+        return CUBLAS_COMPUTE_16F;
       case bfloat16:
        return CUBLAS_COMPUTE_32F;
       case float32:
-        return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
-                                             : CUBLAS_COMPUTE_32F;
+        return CUBLAS_COMPUTE_32F;
       case float64:
       case complex64:
         return CUBLAS_COMPUTE_64F;
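
Taken together with the scale_type hunk above, the right-hand side of this diff pairs each dtype with a cuBLASLt compute/scale type as sketched below. This is a reconstruction for illustration only: it assumes dtype_to_cuda_type(float16) maps to CUDA_R_16F, and the helper names are not part of the MLX MatMul class.

// Sketch of the dtype -> cuBLASLt type pairing implied by the new side of
// this diff; the enum and struct here are illustrative, not MLX types.
#include <cublasLt.h>

struct GemmTypes {
  cublasComputeType_t compute;
  cudaDataType_t scale;
};

enum class SketchDtype { f16, bf16, f32 };

inline GemmTypes gemm_types(SketchDtype dtype) {
  switch (dtype) {
    case SketchDtype::f16:
      // float16 now accumulates in half precision with a half scale factor.
      return {CUBLAS_COMPUTE_16F, CUDA_R_16F};
    case SketchDtype::bf16:
      // bfloat16 keeps an FP32 accumulator and FP32 scale.
      return {CUBLAS_COMPUTE_32F, CUDA_R_32F};
    case SketchDtype::f32:
    default:
      // float32 always uses plain FP32 now that the TF32 toggle is gone.
      return {CUBLAS_COMPUTE_32F, CUDA_R_32F};
  }
}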

File 3 of 3

@@ -149,11 +149,6 @@ inline bool metal_fast_synch() {
   return metal_fast_synch;
 }
-inline bool enable_tf32() {
-  static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
-  return enable_tf32_;
-}
 } // namespace env
 } // namespace mlx::core
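
For reference, the deleted enable_tf32() followed the same cached-environment-variable pattern as the surviving helpers in this header; with it gone, the MLX_ENABLE_TF32 variable no longer has a caller. A self-contained sketch of that pattern is below; get_var here is a guess at the helper's behaviour (its real definition is not part of this diff).

// Illustrative only: how a toggle like the removed enable_tf32() reads its
// environment variable once and caches the result in a function-local static.
#include <cstdlib>

namespace sketch {

inline int get_var(const char* name, int default_value) {
  const char* value = std::getenv(name);
  return value ? std::atoi(value) : default_value;
}

inline bool enable_tf32() {
  // Defaulted to on; exporting MLX_ENABLE_TF32=0 turned the fast path off.
  static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
  return enable_tf32_;
}

} // namespace sketch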