mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Introduce macros for dispatching dynamic dtypes as static types (#2073)
This commit is contained in:
110
mlx/utils.cpp
110
mlx/utils.cpp
@@ -5,6 +5,7 @@
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/types/limits.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) {
|
||||
} // namespace
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
|
||||
switch (dtype) {
|
||||
case bool_:
|
||||
return os << "bool";
|
||||
case uint8:
|
||||
return os << "uint8";
|
||||
case uint16:
|
||||
return os << "uint16";
|
||||
case uint32:
|
||||
return os << "uint32";
|
||||
case uint64:
|
||||
return os << "uint64";
|
||||
case int8:
|
||||
return os << "int8";
|
||||
case int16:
|
||||
return os << "int16";
|
||||
case int32:
|
||||
return os << "int32";
|
||||
case int64:
|
||||
return os << "int64";
|
||||
case float16:
|
||||
return os << "float16";
|
||||
case float32:
|
||||
return os << "float32";
|
||||
case float64:
|
||||
return os << "float64";
|
||||
case bfloat16:
|
||||
return os << "bfloat16";
|
||||
case complex64:
|
||||
return os << "complex64";
|
||||
}
|
||||
return os;
|
||||
return os << dtype_to_string(dtype);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
||||
@@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, array a) {
|
||||
a.eval();
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
print_array<bool>(os, a);
|
||||
break;
|
||||
case uint8:
|
||||
print_array<uint8_t>(os, a);
|
||||
break;
|
||||
case uint16:
|
||||
print_array<uint16_t>(os, a);
|
||||
break;
|
||||
case uint32:
|
||||
print_array<uint32_t>(os, a);
|
||||
break;
|
||||
case uint64:
|
||||
print_array<uint64_t>(os, a);
|
||||
break;
|
||||
case int8:
|
||||
print_array<int8_t>(os, a);
|
||||
break;
|
||||
case int16:
|
||||
print_array<int16_t>(os, a);
|
||||
break;
|
||||
case int32:
|
||||
print_array<int32_t>(os, a);
|
||||
break;
|
||||
case int64:
|
||||
print_array<int64_t>(os, a);
|
||||
break;
|
||||
case float16:
|
||||
print_array<float16_t>(os, a);
|
||||
break;
|
||||
case bfloat16:
|
||||
print_array<bfloat16_t>(os, a);
|
||||
break;
|
||||
case float32:
|
||||
print_array<float>(os, a);
|
||||
break;
|
||||
case float64:
|
||||
print_array<double>(os, a);
|
||||
break;
|
||||
case complex64:
|
||||
print_array<complex64_t>(os, a);
|
||||
break;
|
||||
}
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
|
||||
return os;
|
||||
}
|
||||
|
||||
@@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
|
||||
}
|
||||
|
||||
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
set_iinfo_limits<int8_t>(min, max);
|
||||
break;
|
||||
case uint8:
|
||||
set_iinfo_limits<uint8_t>(min, max);
|
||||
break;
|
||||
case int16:
|
||||
set_iinfo_limits<int16_t>(min, max);
|
||||
break;
|
||||
case uint16:
|
||||
set_iinfo_limits<uint16_t>(min, max);
|
||||
break;
|
||||
case int32:
|
||||
set_iinfo_limits<int32_t>(min, max);
|
||||
break;
|
||||
case uint32:
|
||||
set_iinfo_limits<uint32_t>(min, max);
|
||||
break;
|
||||
case int64:
|
||||
set_iinfo_limits<int64_t>(min, max);
|
||||
break;
|
||||
case uint64:
|
||||
set_iinfo_limits<uint64_t>(min, max);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream msg;
|
||||
msg << "[iinfo] dtype " << dtype << " is not integral.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
MLX_SWITCH_INT_TYPES_CHECKED(
|
||||
dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user