Compare commits

..

1 Commits

Author SHA1 Message Date
Awni Hannun
c5be966863
Merge 688e421184 into 6871e2eeb7 2025-06-13 21:55:29 -07:00
3 changed files with 23 additions and 18 deletions

View File

@ -1,4 +1,5 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
@ -112,7 +113,7 @@ __global__ void arg_reduce_general(
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
auto tid = r * BLOCK_DIM + block.thread_index().z;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
@ -157,7 +158,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{BLOCK_DIM, 1, 1};
dim3 block_dims{1, 1, BLOCK_DIM};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,

View File

@ -5,7 +5,6 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h>
@ -45,12 +44,9 @@ class MatMul {
int64_t b_batch_stride) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
auto type = dtype_to_cuda_type(dtype);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
&matmul_desc_, dtype_to_compute_type(dtype), type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@ -69,7 +65,6 @@ class MatMul {
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cuda_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
@ -192,13 +187,17 @@ class MatMul {
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case uint8:
case uint16:
case int8:
case int16:
case int32:
return CUBLAS_COMPUTE_32I;
case float16:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
return CUBLAS_COMPUTE_16F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
return CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
@ -210,6 +209,16 @@ class MatMul {
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case uint8:
return CUDA_R_8U;
case uint16:
return CUDA_R_16U;
case int8:
return CUDA_R_8I;
case int16:
return CUDA_R_16I;
case int32:
return CUDA_R_32I;
case float16:
return CUDA_R_16F;
case bfloat16:

View File

@ -149,11 +149,6 @@ inline bool metal_fast_synch() {
return metal_fast_synch;
}
inline bool enable_tf32() {
static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
return enable_tf32_;
}
} // namespace env
} // namespace mlx::core