mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add mx.finfo
and use it when making causal mask (#1726)
* finfo * fixes * docs
This commit is contained in:
@@ -102,9 +102,7 @@ class MultiHeadAttention(Module):
|
||||
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
||||
indices = mx.arange(N)
|
||||
mask = indices[:, None] < indices[None]
|
||||
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
|
||||
# TODO: Should replace this with finfo(dtype).min
|
||||
mask = mask.astype(dtype) * -1e9
|
||||
mask = mask.astype(dtype) * mx.finfo(dtype).min
|
||||
return mask
|
||||
|
||||
|
||||
|
@@ -179,6 +179,31 @@ void init_array(nb::module_& m) {
|
||||
.value("number", mx::number)
|
||||
.value("generic", mx::generic)
|
||||
.export_values();
|
||||
|
||||
nb::class_<mx::finfo>(
|
||||
m,
|
||||
"finfo",
|
||||
R"pbdoc(
|
||||
Get information on floating-point types.
|
||||
)pbdoc")
|
||||
.def(nb::init<mx::Dtype>())
|
||||
.def_ro(
|
||||
"min",
|
||||
&mx::finfo::min,
|
||||
R"pbdoc(The smallest representable number.)pbdoc")
|
||||
.def_ro(
|
||||
"max",
|
||||
&mx::finfo::max,
|
||||
R"pbdoc(The largest representable number.)pbdoc")
|
||||
.def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
|
||||
.def("__repr__", [](const mx::finfo& f) {
|
||||
std::ostringstream os;
|
||||
os << "finfo("
|
||||
<< "min=" << f.min << ", max=" << f.max << ", dtype=" << f.dtype
|
||||
<< ")";
|
||||
return os.str();
|
||||
});
|
||||
|
||||
nb::class_<ArrayAt>(
|
||||
m,
|
||||
"ArrayAt",
|
||||
|
@@ -97,6 +97,18 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(z.shape), list(x.shape))
|
||||
self.assertListEqual(list(z.shape), list(y.shape))
|
||||
|
||||
def test_finfo(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.finfo(mx.int32)
|
||||
|
||||
self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min)
|
||||
self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max)
|
||||
self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32)
|
||||
|
||||
self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min)
|
||||
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
|
||||
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
|
||||
|
||||
|
||||
class TestEquality(mlx_tests.MLXTestCase):
|
||||
def test_array_eq_array(self):
|
||||
|
@@ -1826,6 +1826,15 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertGreater(cosine(y, yq).min(), 0.99)
|
||||
|
||||
def test_causal_mask(self):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.float16)
|
||||
self.assertFalse(mx.any(mx.isnan(mask)))
|
||||
self.assertTrue(mask[0, -1].item() < 0)
|
||||
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.bfloat16)
|
||||
self.assertFalse(mx.any(mx.isnan(mask)))
|
||||
self.assertTrue(mask[0, -1].item() < 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user