mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Add mx.finfo
and use it when making causal mask (#1726)
* finfo * fixes * docs
This commit is contained in:
61
mlx/types/limits.h
Normal file
61
mlx/types/limits.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
#include "mlx/types/half_types.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
struct numeric_limits;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<float> : public std::numeric_limits<float> {};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<float16_t> {
|
||||
private:
|
||||
union half_or_bits {
|
||||
uint16_t bits;
|
||||
float16_t value;
|
||||
};
|
||||
constexpr static float16_t bits_to_half(uint16_t v) {
|
||||
return half_or_bits{v}.value;
|
||||
}
|
||||
|
||||
public:
|
||||
constexpr static float16_t lowest() {
|
||||
return bits_to_half(0xFBFF);
|
||||
}
|
||||
static constexpr float16_t max() {
|
||||
return bits_to_half(0x7BFF);
|
||||
}
|
||||
static constexpr float16_t infinity() {
|
||||
return bits_to_half(0x7C00);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<bfloat16_t> {
|
||||
private:
|
||||
union bfloat_or_bits {
|
||||
uint16_t bits;
|
||||
bfloat16_t value;
|
||||
};
|
||||
constexpr static bfloat16_t bits_to_bfloat(uint16_t v) {
|
||||
return bfloat_or_bits{v}.value;
|
||||
}
|
||||
|
||||
public:
|
||||
constexpr static bfloat16_t lowest() {
|
||||
return bits_to_bfloat(0xFF7F);
|
||||
}
|
||||
static constexpr bfloat16_t max() {
|
||||
return bits_to_bfloat(0x7F7F);
|
||||
}
|
||||
static constexpr bfloat16_t infinity() {
|
||||
return bits_to_bfloat(0x7F80);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@@ -4,6 +4,7 @@
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/types/limits.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -326,4 +327,28 @@ int get_var(const char* name, int default_value) {
|
||||
|
||||
} // namespace env
|
||||
|
||||
template <typename T>
|
||||
void set_finfo_limits(float& min, float& max) {
|
||||
min = numeric_limits<T>::lowest();
|
||||
max = numeric_limits<T>::max();
|
||||
}
|
||||
|
||||
finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
if (!issubdtype(dtype, inexact)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[finfo] dtype " << dtype << " is not inexact.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (dtype == float32) {
|
||||
set_finfo_limits<float>(min, max);
|
||||
} else if (dtype == float16) {
|
||||
set_finfo_limits<float16_t>(min, max);
|
||||
} else if (dtype == bfloat16) {
|
||||
set_finfo_limits<bfloat16_t>(min, max);
|
||||
} else if (dtype == complex64) {
|
||||
this->dtype = float32;
|
||||
set_finfo_limits<float>(min, max);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -53,6 +53,14 @@ struct PrintFormatter {
|
||||
|
||||
PrintFormatter& get_global_formatter();
|
||||
|
||||
/** Holds information about floating-point types. */
|
||||
struct finfo {
|
||||
explicit finfo(Dtype dtype);
|
||||
Dtype dtype;
|
||||
float min;
|
||||
float max;
|
||||
};
|
||||
|
||||
/** The type from promoting the arrays' types with one another. */
|
||||
inline Dtype result_type(const array& a, const array& b) {
|
||||
return promote_types(a.dtype(), b.dtype());
|
||||
|
Reference in New Issue
Block a user