make behaviour of dtype arguments consistent and compliant to numpy (#379)

All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed

This important when passing kw args down from a numpy function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`: #378

Closes #378
This commit is contained in:
Daniel Strobusch
2024-01-05 18:37:46 +01:00
committed by GitHub
parent d8f41a5c0f
commit dfdb284e16
4 changed files with 75 additions and 29 deletions

View File

@@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]:
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
def test_tril(self):
for diag in [-1, 0, 1, -2]: