Fix dtype_to_cuda_type

This commit is contained in:
Angelos Katharopoulos
2025-06-29 17:35:29 -07:00
parent 4aca09339d
commit 48fa6ae1dd
2 changed files with 33 additions and 10 deletions

View File

@@ -25,16 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) {
} }
const char* dtype_to_cuda_type(const Dtype& dtype) { const char* dtype_to_cuda_type(const Dtype& dtype) {
if (dtype == float16) { switch (dtype) {
return "__half"; case bool_:
return "bool";
case int8:
return "int8_t";
case int16:
return "int16_t";
case int32:
return "int32_t";
case int64:
return "int64_t";
case uint8:
return "uint8_t";
case uint16:
return "uint16_t";
case uint32:
return "uint32_t";
case uint64:
return "uint64_t";
case float16:
return "__half";
case bfloat16:
return "__nv_bfloat16";
case float32:
return "float";
case float64:
return "double";
case complex64:
return "cuComplex";
default:
return "unknown";
} }
if (dtype == bfloat16) {
return "__nv_bfloat16";
}
if (dtype == complex64) {
return "cuComplex";
}
return dtype_to_string(dtype);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -34,8 +34,9 @@ const char* dtype_to_string(Dtype arg) {
return "float64"; return "float64";
case complex64: case complex64:
return "complex64"; return "complex64";
default:
return "unknown";
} }
return "unknown";
} }
} // namespace mlx::core } // namespace mlx::core