mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should
* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed
This important when passing kw args down from a numpy function like:
```
def f(x, dtype=None):
mx.random.uniform(dtype=dtype)
# ...
```
NumPy functions behave like this.
It also fixes a minor bug in `tri`: #378
Closes #378
This commit is contained in:
@@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for shape in [[4], [4, 4], [2, 10]]:
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
|
||||
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
|
||||
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
|
||||
|
||||
def test_tril(self):
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
|
||||
Reference in New Issue
Block a user