mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 22:34:43 +08:00
make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should * have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`) * behave identical when `dtype=None` or no dtype is passed This important when passing kw args down from a numpy function like: ``` def f(x, dtype=None): mx.random.uniform(dtype=dtype) # ... ``` NumPy functions behave like this. It also fixes a minor bug in `tri`: #378 Closes #378
This commit is contained in:
@@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for shape in [[4], [4, 4], [2, 10]]:
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
|
||||
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
|
||||
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
|
||||
|
||||
def test_tril(self):
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
|
@@ -61,6 +61,8 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
|
||||
self.assertEqual(a.dtype, mx.bfloat16)
|
||||
|
||||
self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype)
|
||||
|
||||
def test_normal(self):
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.normal(key=key)
|
||||
@@ -78,6 +80,8 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
a = mx.random.normal(dtype=t)
|
||||
self.assertEqual(a.dtype, t)
|
||||
|
||||
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
|
||||
|
||||
def test_randint(self):
|
||||
a = mx.random.randint(0, 1, [])
|
||||
self.assertEqual(a.shape, [])
|
||||
@@ -109,6 +113,10 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
a = mx.random.randint(10, -10, [1000, 1000])
|
||||
self.assertTrue(mx.all(a == 10).item())
|
||||
|
||||
self.assertEqual(
|
||||
mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype
|
||||
)
|
||||
|
||||
def test_bernoulli(self):
|
||||
a = mx.random.bernoulli()
|
||||
self.assertEqual(a.shape, [])
|
||||
@@ -159,6 +167,11 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.truncated_normal(lower, higher) # Bad shape
|
||||
|
||||
self.assertEqual(
|
||||
mx.random.truncated_normal(0, 1).dtype,
|
||||
mx.random.truncated_normal(0, 1, dtype=None).dtype,
|
||||
)
|
||||
|
||||
def test_gumbel(self):
|
||||
samples = mx.random.gumbel(shape=(100, 100))
|
||||
self.assertEqual(samples.shape, [100, 100])
|
||||
@@ -168,6 +181,10 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
# so this test is pretty conservative
|
||||
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
|
||||
|
||||
self.assertEqual(
|
||||
mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype
|
||||
)
|
||||
|
||||
def test_categorical(self):
|
||||
logits = mx.zeros((10, 20))
|
||||
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
|
||||
|
Reference in New Issue
Block a user