mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should * have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`) * behave identical when `dtype=None` or no dtype is passed This important when passing kw args down from a numpy function like: ``` def f(x, dtype=None): mx.random.uniform(dtype=dtype) # ... ``` NumPy functions behave like this. It also fixes a minor bug in `tri`: #378 Closes #378
This commit is contained in:
@@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for shape in [[4], [4, 4], [2, 10]]:
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
|
||||
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
|
||||
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
|
||||
|
||||
def test_tril(self):
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
|
Reference in New Issue
Block a user