mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fp64 on the CPU (#1843)
* add fp64 data type * clean build * update docs * fix bug
This commit is contained in:
@@ -54,6 +54,9 @@ inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {
|
||||
inline void PrintFormatter::print(std::ostream& os, float val) {
|
||||
os << val;
|
||||
}
|
||||
inline void PrintFormatter::print(std::ostream& os, double val) {
|
||||
os << val;
|
||||
}
|
||||
inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
|
||||
os << val;
|
||||
}
|
||||
@@ -234,6 +237,8 @@ std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
|
||||
return os << "float16";
|
||||
case float32:
|
||||
return os << "float32";
|
||||
case float64:
|
||||
return os << "float64";
|
||||
case bfloat16:
|
||||
return os << "bfloat16";
|
||||
case complex64:
|
||||
@@ -299,6 +304,9 @@ std::ostream& operator<<(std::ostream& os, array a) {
|
||||
case float32:
|
||||
print_array<float>(os, a);
|
||||
break;
|
||||
case float64:
|
||||
print_array<double>(os, a);
|
||||
break;
|
||||
case complex64:
|
||||
print_array<complex64_t>(os, a);
|
||||
break;
|
||||
@@ -337,7 +345,7 @@ int get_var(const char* name, int default_value) {
|
||||
} // namespace env
|
||||
|
||||
template <typename T>
|
||||
void set_finfo_limits(float& min, float& max) {
|
||||
void set_finfo_limits(double& min, double& max) {
|
||||
min = numeric_limits<T>::lowest();
|
||||
max = numeric_limits<T>::max();
|
||||
}
|
||||
@@ -354,6 +362,8 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
set_finfo_limits<float16_t>(min, max);
|
||||
} else if (dtype == bfloat16) {
|
||||
set_finfo_limits<bfloat16_t>(min, max);
|
||||
} else if (dtype == float64) {
|
||||
set_finfo_limits<double>(min, max);
|
||||
} else if (dtype == complex64) {
|
||||
this->dtype = float32;
|
||||
set_finfo_limits<float>(min, max);
|
||||
|
||||
Reference in New Issue
Block a user