diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 34f22ba7f..35731f6eb 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -25,16 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) { } const char* dtype_to_cuda_type(const Dtype& dtype) { - if (dtype == float16) { - return "__half"; + switch (dtype) { + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + case float16: + return "__half"; + case bfloat16: + return "__nv_bfloat16"; + case float32: + return "float"; + case float64: + return "double"; + case complex64: + return "cuComplex"; + default: + return "unknown"; } - if (dtype == bfloat16) { - return "__nv_bfloat16"; - } - if (dtype == complex64) { - return "cuComplex"; - } - return dtype_to_string(dtype); } } // namespace mlx::core diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp index 270949ad6..9f10e6a9a 100644 --- a/mlx/dtype_utils.cpp +++ b/mlx/dtype_utils.cpp @@ -34,8 +34,9 @@ const char* dtype_to_string(Dtype arg) { return "float64"; case complex64: return "complex64"; + default: + return "unknown"; } - return "unknown"; } } // namespace mlx::core