mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add mx.finfo
and use it when making causal mask (#1726)
* finfo * fixes * docs
This commit is contained in:
parent
e03f0372b1
commit
c3628eea49
@ -66,3 +66,4 @@ documentation for more information. Use :func:`issubdtype` to determine if one
|
|||||||
Dtype
|
Dtype
|
||||||
DtypeCategory
|
DtypeCategory
|
||||||
issubdtype
|
issubdtype
|
||||||
|
finfo
|
||||||
|
61
mlx/types/limits.h
Normal file
61
mlx/types/limits.h
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include "mlx/types/half_types.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct numeric_limits;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<float> : public std::numeric_limits<float> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<float16_t> {
|
||||||
|
private:
|
||||||
|
union half_or_bits {
|
||||||
|
uint16_t bits;
|
||||||
|
float16_t value;
|
||||||
|
};
|
||||||
|
constexpr static float16_t bits_to_half(uint16_t v) {
|
||||||
|
return half_or_bits{v}.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
constexpr static float16_t lowest() {
|
||||||
|
return bits_to_half(0xFBFF);
|
||||||
|
}
|
||||||
|
static constexpr float16_t max() {
|
||||||
|
return bits_to_half(0x7BFF);
|
||||||
|
}
|
||||||
|
static constexpr float16_t infinity() {
|
||||||
|
return bits_to_half(0x7C00);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct numeric_limits<bfloat16_t> {
|
||||||
|
private:
|
||||||
|
union bfloat_or_bits {
|
||||||
|
uint16_t bits;
|
||||||
|
bfloat16_t value;
|
||||||
|
};
|
||||||
|
constexpr static bfloat16_t bits_to_bfloat(uint16_t v) {
|
||||||
|
return bfloat_or_bits{v}.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
constexpr static bfloat16_t lowest() {
|
||||||
|
return bits_to_bfloat(0xFF7F);
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t max() {
|
||||||
|
return bits_to_bfloat(0x7F7F);
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t infinity() {
|
||||||
|
return bits_to_bfloat(0x7F80);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -4,6 +4,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/types/limits.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -326,4 +327,28 @@ int get_var(const char* name, int default_value) {
|
|||||||
|
|
||||||
} // namespace env
|
} // namespace env
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void set_finfo_limits(float& min, float& max) {
|
||||||
|
min = numeric_limits<T>::lowest();
|
||||||
|
max = numeric_limits<T>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||||
|
if (!issubdtype(dtype, inexact)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[finfo] dtype " << dtype << " is not inexact.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (dtype == float32) {
|
||||||
|
set_finfo_limits<float>(min, max);
|
||||||
|
} else if (dtype == float16) {
|
||||||
|
set_finfo_limits<float16_t>(min, max);
|
||||||
|
} else if (dtype == bfloat16) {
|
||||||
|
set_finfo_limits<bfloat16_t>(min, max);
|
||||||
|
} else if (dtype == complex64) {
|
||||||
|
this->dtype = float32;
|
||||||
|
set_finfo_limits<float>(min, max);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -53,6 +53,14 @@ struct PrintFormatter {
|
|||||||
|
|
||||||
PrintFormatter& get_global_formatter();
|
PrintFormatter& get_global_formatter();
|
||||||
|
|
||||||
|
/** Holds information about floating-point types. */
|
||||||
|
struct finfo {
|
||||||
|
explicit finfo(Dtype dtype);
|
||||||
|
Dtype dtype;
|
||||||
|
float min;
|
||||||
|
float max;
|
||||||
|
};
|
||||||
|
|
||||||
/** The type from promoting the arrays' types with one another. */
|
/** The type from promoting the arrays' types with one another. */
|
||||||
inline Dtype result_type(const array& a, const array& b) {
|
inline Dtype result_type(const array& a, const array& b) {
|
||||||
return promote_types(a.dtype(), b.dtype());
|
return promote_types(a.dtype(), b.dtype());
|
||||||
|
@ -102,9 +102,7 @@ class MultiHeadAttention(Module):
|
|||||||
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
||||||
indices = mx.arange(N)
|
indices = mx.arange(N)
|
||||||
mask = indices[:, None] < indices[None]
|
mask = indices[:, None] < indices[None]
|
||||||
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
|
mask = mask.astype(dtype) * mx.finfo(dtype).min
|
||||||
# TODO: Should replace this with finfo(dtype).min
|
|
||||||
mask = mask.astype(dtype) * -1e9
|
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
@ -179,6 +179,31 @@ void init_array(nb::module_& m) {
|
|||||||
.value("number", mx::number)
|
.value("number", mx::number)
|
||||||
.value("generic", mx::generic)
|
.value("generic", mx::generic)
|
||||||
.export_values();
|
.export_values();
|
||||||
|
|
||||||
|
nb::class_<mx::finfo>(
|
||||||
|
m,
|
||||||
|
"finfo",
|
||||||
|
R"pbdoc(
|
||||||
|
Get information on floating-point types.
|
||||||
|
)pbdoc")
|
||||||
|
.def(nb::init<mx::Dtype>())
|
||||||
|
.def_ro(
|
||||||
|
"min",
|
||||||
|
&mx::finfo::min,
|
||||||
|
R"pbdoc(The smallest representable number.)pbdoc")
|
||||||
|
.def_ro(
|
||||||
|
"max",
|
||||||
|
&mx::finfo::max,
|
||||||
|
R"pbdoc(The largest representable number.)pbdoc")
|
||||||
|
.def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
|
||||||
|
.def("__repr__", [](const mx::finfo& f) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "finfo("
|
||||||
|
<< "min=" << f.min << ", max=" << f.max << ", dtype=" << f.dtype
|
||||||
|
<< ")";
|
||||||
|
return os.str();
|
||||||
|
});
|
||||||
|
|
||||||
nb::class_<ArrayAt>(
|
nb::class_<ArrayAt>(
|
||||||
m,
|
m,
|
||||||
"ArrayAt",
|
"ArrayAt",
|
||||||
|
@ -97,6 +97,18 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(z.shape), list(x.shape))
|
self.assertListEqual(list(z.shape), list(x.shape))
|
||||||
self.assertListEqual(list(z.shape), list(y.shape))
|
self.assertListEqual(list(z.shape), list(y.shape))
|
||||||
|
|
||||||
|
def test_finfo(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.finfo(mx.int32)
|
||||||
|
|
||||||
|
self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min)
|
||||||
|
self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max)
|
||||||
|
self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32)
|
||||||
|
|
||||||
|
self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min)
|
||||||
|
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
|
||||||
|
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
|
||||||
|
|
||||||
|
|
||||||
class TestEquality(mlx_tests.MLXTestCase):
|
class TestEquality(mlx_tests.MLXTestCase):
|
||||||
def test_array_eq_array(self):
|
def test_array_eq_array(self):
|
||||||
|
@ -1826,6 +1826,15 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertGreater(cosine(y, yq).min(), 0.99)
|
self.assertGreater(cosine(y, yq).min(), 0.99)
|
||||||
|
|
||||||
|
def test_causal_mask(self):
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.float16)
|
||||||
|
self.assertFalse(mx.any(mx.isnan(mask)))
|
||||||
|
self.assertTrue(mask[0, -1].item() < 0)
|
||||||
|
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.bfloat16)
|
||||||
|
self.assertFalse(mx.any(mx.isnan(mask)))
|
||||||
|
self.assertTrue(mask[0, -1].item() < 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -43,3 +43,15 @@ TEST_CASE("test normalize axis") {
|
|||||||
CHECK_THROWS(normalize_axis_index(3, 3));
|
CHECK_THROWS(normalize_axis_index(3, 3));
|
||||||
CHECK_THROWS(normalize_axis_index(-4, 3));
|
CHECK_THROWS(normalize_axis_index(-4, 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test finfo") {
|
||||||
|
CHECK_EQ(finfo(float32).dtype, float32);
|
||||||
|
CHECK_EQ(finfo(complex64).dtype, float32);
|
||||||
|
CHECK_EQ(finfo(float16).dtype, float16);
|
||||||
|
CHECK_EQ(finfo(float32).min, std::numeric_limits<float>::lowest());
|
||||||
|
CHECK_EQ(finfo(float32).max, std::numeric_limits<float>::max());
|
||||||
|
CHECK_EQ(finfo(complex64).min, std::numeric_limits<float>::lowest());
|
||||||
|
CHECK_EQ(finfo(complex64).max, std::numeric_limits<float>::max());
|
||||||
|
CHECK_EQ(finfo(float16).min, -65504);
|
||||||
|
CHECK_EQ(finfo(float16).max, 65504);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user