Add mx.finfo and use it when making causal mask (#1726)

* finfo

* fixes

* docs
This commit is contained in:
Awni Hannun
2024-12-19 14:52:41 -08:00
committed by GitHub
parent e03f0372b1
commit c3628eea49
9 changed files with 154 additions and 3 deletions

View File

@@ -4,6 +4,7 @@
#include <sstream>
#include <vector>
#include "mlx/types/limits.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -326,4 +327,28 @@ int get_var(const char* name, int default_value) {
} // namespace env
template <typename T>
void set_finfo_limits(float& min, float& max) {
min = numeric_limits<T>::lowest();
max = numeric_limits<T>::max();
}
finfo::finfo(Dtype dtype) : dtype(dtype) {
if (!issubdtype(dtype, inexact)) {
std::ostringstream msg;
msg << "[finfo] dtype " << dtype << " is not inexact.";
throw std::invalid_argument(msg.str());
}
if (dtype == float32) {
set_finfo_limits<float>(min, max);
} else if (dtype == float16) {
set_finfo_limits<float16_t>(min, max);
} else if (dtype == bfloat16) {
set_finfo_limits<bfloat16_t>(min, max);
} else if (dtype == complex64) {
this->dtype = float32;
set_finfo_limits<float>(min, max);
}
}
} // namespace mlx::core