From 48fa6ae1dd3d7a38cf6aeb3bbf81c172192754a5 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 29 Jun 2025 17:35:29 -0700 Subject: [PATCH] Fix dtype_to_cuda_type --- mlx/backend/cuda/utils.cpp | 40 +++++++++++++++++++++++++++++--------- mlx/dtype_utils.cpp | 3 ++- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 34f22ba7f..35731f6eb 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -25,16 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) { } const char* dtype_to_cuda_type(const Dtype& dtype) { - if (dtype == float16) { - return "__half"; + switch (dtype) { + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + case float16: + return "__half"; + case bfloat16: + return "__nv_bfloat16"; + case float32: + return "float"; + case float64: + return "double"; + case complex64: + return "cuComplex"; + default: + return "unknown"; } - if (dtype == bfloat16) { - return "__nv_bfloat16"; - } - if (dtype == complex64) { - return "cuComplex"; - } - return dtype_to_string(dtype); } } // namespace mlx::core diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp index 270949ad6..9f10e6a9a 100644 --- a/mlx/dtype_utils.cpp +++ b/mlx/dtype_utils.cpp @@ -34,8 +34,9 @@ const char* dtype_to_string(Dtype arg) { return "float64"; case complex64: return "complex64"; + default: + return "unknown"; } - return "unknown"; } } // namespace mlx::core