make behaviour of dtype arguments consistent and compliant to numpy (#379)

All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed

This important when passing kw args down from a numpy function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`: #378

Closes #378
This commit is contained in:
Daniel Strobusch
2024-01-05 18:37:46 +01:00
committed by GitHub
parent d8f41a5c0f
commit dfdb284e16
4 changed files with 75 additions and 29 deletions

View File

@@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]:
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
def test_tril(self):
for diag in [-1, 0, 1, -2]:

View File

@@ -61,6 +61,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
self.assertEqual(a.dtype, mx.bfloat16)
self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype)
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)
@@ -78,6 +80,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.normal(dtype=t)
self.assertEqual(a.dtype, t)
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, [])
@@ -109,6 +113,10 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.randint(10, -10, [1000, 1000])
self.assertTrue(mx.all(a == 10).item())
self.assertEqual(
mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype
)
def test_bernoulli(self):
a = mx.random.bernoulli()
self.assertEqual(a.shape, [])
@@ -159,6 +167,11 @@ class TestRandom(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.random.truncated_normal(lower, higher) # Bad shape
self.assertEqual(
mx.random.truncated_normal(0, 1).dtype,
mx.random.truncated_normal(0, 1, dtype=None).dtype,
)
def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100])
@@ -168,6 +181,10 @@ class TestRandom(mlx_tests.MLXTestCase):
# so this test is pretty conservative
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
self.assertEqual(
mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype
)
def test_categorical(self):
logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])