diff --git a/docs/src/python/data_types.rst b/docs/src/python/data_types.rst index 549446447..c75bfcb9d 100644 --- a/docs/src/python/data_types.rst +++ b/docs/src/python/data_types.rst @@ -66,3 +66,4 @@ documentation for more information. Use :func:`issubdtype` to determine if one Dtype DtypeCategory issubdtype + finfo diff --git a/mlx/types/limits.h b/mlx/types/limits.h new file mode 100644 index 000000000..6f2668a5f --- /dev/null +++ b/mlx/types/limits.h @@ -0,0 +1,61 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include +#include "mlx/types/half_types.h" + +namespace mlx::core { + +template +struct numeric_limits; + +template <> +struct numeric_limits : public std::numeric_limits {}; + +template <> +struct numeric_limits { + private: + union half_or_bits { + uint16_t bits; + float16_t value; + }; + constexpr static float16_t bits_to_half(uint16_t v) { + return half_or_bits{v}.value; + } + + public: + constexpr static float16_t lowest() { + return bits_to_half(0xFBFF); + } + static constexpr float16_t max() { + return bits_to_half(0x7BFF); + } + static constexpr float16_t infinity() { + return bits_to_half(0x7C00); + } +}; + +template <> +struct numeric_limits { + private: + union bfloat_or_bits { + uint16_t bits; + bfloat16_t value; + }; + constexpr static bfloat16_t bits_to_bfloat(uint16_t v) { + return bfloat_or_bits{v}.value; + } + + public: + constexpr static bfloat16_t lowest() { + return bits_to_bfloat(0xFF7F); + } + static constexpr bfloat16_t max() { + return bits_to_bfloat(0x7F7F); + } + static constexpr bfloat16_t infinity() { + return bits_to_bfloat(0x7F80); + } +}; + +} // namespace mlx::core diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 3471ef566..6a840172f 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -4,6 +4,7 @@ #include #include +#include "mlx/types/limits.h" #include "mlx/utils.h" namespace mlx::core { @@ -326,4 +327,28 @@ int get_var(const char* name, int default_value) { } // namespace env +template +void set_finfo_limits(float& min, float& max) { + min = numeric_limits::lowest(); + max = numeric_limits::max(); +} + +finfo::finfo(Dtype dtype) : dtype(dtype) { + if (!issubdtype(dtype, inexact)) { + std::ostringstream msg; + msg << "[finfo] dtype " << dtype << " is not inexact."; + throw std::invalid_argument(msg.str()); + } + if (dtype == float32) { + set_finfo_limits(min, max); + } else if (dtype == float16) { + set_finfo_limits(min, max); + } else if (dtype == bfloat16) { + set_finfo_limits(min, max); + } else if (dtype == complex64) { + this->dtype = float32; + set_finfo_limits(min, max); + } +} + } // namespace mlx::core diff --git a/mlx/utils.h b/mlx/utils.h index 730bf0315..28134bd20 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -53,6 +53,14 @@ struct PrintFormatter { PrintFormatter& get_global_formatter(); +/** Holds information about floating-point types. */ +struct finfo { + explicit finfo(Dtype dtype); + Dtype dtype; + float min; + float max; +}; + /** The type from promoting the arrays' types with one another. */ inline Dtype result_type(const array& a, const array& b) { return promote_types(a.dtype(), b.dtype()); diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index 79e87dc1d..f3df7986c 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -102,9 +102,7 @@ class MultiHeadAttention(Module): def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32): indices = mx.arange(N) mask = indices[:, None] < indices[None] - # usually inf but 1e9 is as good and softmax(full(1e9)) != nan - # TODO: Should replace this with finfo(dtype).min - mask = mask.astype(dtype) * -1e9 + mask = mask.astype(dtype) * mx.finfo(dtype).min return mask diff --git a/python/src/array.cpp b/python/src/array.cpp index f35236ede..ff358cbe4 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -179,6 +179,31 @@ void init_array(nb::module_& m) { .value("number", mx::number) .value("generic", mx::generic) .export_values(); + + nb::class_( + m, + "finfo", + R"pbdoc( + Get information on floating-point types. + )pbdoc") + .def(nb::init()) + .def_ro( + "min", + &mx::finfo::min, + R"pbdoc(The smallest representable number.)pbdoc") + .def_ro( + "max", + &mx::finfo::max, + R"pbdoc(The largest representable number.)pbdoc") + .def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") + .def("__repr__", [](const mx::finfo& f) { + std::ostringstream os; + os << "finfo(" + << "min=" << f.min << ", max=" << f.max << ", dtype=" << f.dtype + << ")"; + return os.str(); + }); + nb::class_( m, "ArrayAt", diff --git a/python/tests/test_array.py b/python/tests/test_array.py index ef3c6dd2e..1f4515b6b 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -97,6 +97,18 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertListEqual(list(z.shape), list(x.shape)) self.assertListEqual(list(z.shape), list(y.shape)) + def test_finfo(self): + with self.assertRaises(ValueError): + mx.finfo(mx.int32) + + self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min) + self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max) + self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32) + + self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min) + self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) + self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) + class TestEquality(mlx_tests.MLXTestCase): def test_array_eq_array(self): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 445ef3df4..7ca8ba272 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1826,6 +1826,15 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertGreater(cosine(y, yq).min(), 0.99) + def test_causal_mask(self): + mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.float16) + self.assertFalse(mx.any(mx.isnan(mask))) + self.assertTrue(mask[0, -1].item() < 0) + + mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.bfloat16) + self.assertFalse(mx.any(mx.isnan(mask))) + self.assertTrue(mask[0, -1].item() < 0) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp index 666a6b749..a17f12e33 100644 --- a/tests/utils_tests.cpp +++ b/tests/utils_tests.cpp @@ -43,3 +43,15 @@ TEST_CASE("test normalize axis") { CHECK_THROWS(normalize_axis_index(3, 3)); CHECK_THROWS(normalize_axis_index(-4, 3)); } + +TEST_CASE("test finfo") { + CHECK_EQ(finfo(float32).dtype, float32); + CHECK_EQ(finfo(complex64).dtype, float32); + CHECK_EQ(finfo(float16).dtype, float16); + CHECK_EQ(finfo(float32).min, std::numeric_limits::lowest()); + CHECK_EQ(finfo(float32).max, std::numeric_limits::max()); + CHECK_EQ(finfo(complex64).min, std::numeric_limits::lowest()); + CHECK_EQ(finfo(complex64).max, std::numeric_limits::max()); + CHECK_EQ(finfo(float16).min, -65504); + CHECK_EQ(finfo(float16).max, 65504); +}