Introduce macros for dispatching dynamic dtypes as static types (#2073)

This commit is contained in:
Cheng
2025-04-19 21:16:30 +08:00
committed by GitHub
parent 5f04c0f818
commit b13f2aed16
4 changed files with 233 additions and 105 deletions

View File

@@ -5,6 +5,7 @@
#include <sstream>
#include <vector>
#include "mlx/dtype_utils.h"
#include "mlx/types/limits.h"
#include "mlx/utils.h"
@@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) {
} // namespace
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
switch (dtype) {
case bool_:
return os << "bool";
case uint8:
return os << "uint8";
case uint16:
return os << "uint16";
case uint32:
return os << "uint32";
case uint64:
return os << "uint64";
case int8:
return os << "int8";
case int16:
return os << "int16";
case int32:
return os << "int32";
case int64:
return os << "int64";
case float16:
return os << "float16";
case float32:
return os << "float32";
case float64:
return os << "float64";
case bfloat16:
return os << "bfloat16";
case complex64:
return os << "complex64";
}
return os;
return os << dtype_to_string(dtype);
}
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
@@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
std::ostream& operator<<(std::ostream& os, array a) {
a.eval();
switch (a.dtype()) {
case bool_:
print_array<bool>(os, a);
break;
case uint8:
print_array<uint8_t>(os, a);
break;
case uint16:
print_array<uint16_t>(os, a);
break;
case uint32:
print_array<uint32_t>(os, a);
break;
case uint64:
print_array<uint64_t>(os, a);
break;
case int8:
print_array<int8_t>(os, a);
break;
case int16:
print_array<int16_t>(os, a);
break;
case int32:
print_array<int32_t>(os, a);
break;
case int64:
print_array<int64_t>(os, a);
break;
case float16:
print_array<float16_t>(os, a);
break;
case bfloat16:
print_array<bfloat16_t>(os, a);
break;
case float32:
print_array<float>(os, a);
break;
case float64:
print_array<double>(os, a);
break;
case complex64:
print_array<complex64_t>(os, a);
break;
}
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
return os;
}
@@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
}
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
switch (dtype) {
case int8:
set_iinfo_limits<int8_t>(min, max);
break;
case uint8:
set_iinfo_limits<uint8_t>(min, max);
break;
case int16:
set_iinfo_limits<int16_t>(min, max);
break;
case uint16:
set_iinfo_limits<uint16_t>(min, max);
break;
case int32:
set_iinfo_limits<int32_t>(min, max);
break;
case uint32:
set_iinfo_limits<uint32_t>(min, max);
break;
case int64:
set_iinfo_limits<int64_t>(min, max);
break;
case uint64:
set_iinfo_limits<uint64_t>(min, max);
break;
default:
std::ostringstream msg;
msg << "[iinfo] dtype " << dtype << " is not integral.";
throw std::invalid_argument(msg.str());
}
MLX_SWITCH_INT_TYPES_CHECKED(
dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
}
} // namespace mlx::core