Make behaviour of dtype arguments consistent and compliant with NumPy (#379)

All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identically whether `dtype=None` is passed or no dtype is passed at all

This is important when passing keyword arguments down from a NumPy-style function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`, where the `dtype` argument was ignored and `float32` was always used: #378

Closes #378
This commit is contained in:
Daniel Strobusch 2024-01-05 18:37:46 +01:00 committed by GitHub
parent d8f41a5c0f
commit dfdb284e16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 29 deletions

View File

@ -1212,14 +1212,22 @@ void init_ops(py::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"linspace", "linspace",
[](Scalar start, Scalar stop, int num, Dtype dtype, StreamOrDevice s) { [](Scalar start,
Scalar stop,
int num,
std::optional<Dtype> dtype,
StreamOrDevice s) {
return linspace( return linspace(
scalar_to_double(start), scalar_to_double(stop), num, dtype, s); scalar_to_double(start),
scalar_to_double(stop),
num,
dtype.value_or(float32),
s);
}, },
"start"_a, "start"_a,
"stop"_a, "stop"_a,
"num"_a = 50, "num"_a = 50,
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array
@ -1356,11 +1364,11 @@ void init_ops(py::module_& m) {
} }
}, },
"shape"_a, "shape"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of zeros. Construct an array of zeros.
@ -1403,11 +1411,11 @@ void init_ops(py::module_& m) {
} }
}, },
"shape"_a, "shape"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of ones. Construct an array of ones.
@ -1449,11 +1457,11 @@ void init_ops(py::module_& m) {
"n"_a, "n"_a,
"m"_a = py::none(), "m"_a = py::none(),
"k"_a = 0, "k"_a = 0,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Create an identity matrix or a general diagonal matrix. Create an identity matrix or a general diagonal matrix.
@ -1473,11 +1481,11 @@ void init_ops(py::module_& m) {
return identity(n, dtype.value_or(float32), s); return identity(n, dtype.value_or(float32), s);
}, },
"n"_a, "n"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Create a square identity matrix. Create a square identity matrix.
@ -1491,13 +1499,17 @@ void init_ops(py::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"tri", "tri",
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) { [](int n,
return tri(n, m.value_or(n), k, float32, s); std::optional<int> m,
int k,
std::optional<Dtype> type,
StreamOrDevice s) {
return tri(n, m.value_or(n), k, type.value_or(float32), s);
}, },
"n"_a, "n"_a,
"m"_a = none, "m"_a = none,
"k"_a = 0, "k"_a = 0,
"dtype"_a = float32, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(

View File

@ -61,15 +61,21 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
Dtype type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { StreamOrDevice s) {
return uniform(to_array(low), to_array(high), shape, type, key, s); return uniform(
to_array(low),
to_array(high),
shape,
type.value_or(float32),
key,
s);
}, },
"low"_a = 0, "low"_a = 0,
"high"_a = 1, "high"_a = 1,
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"key"_a = none, "key"_a = none,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
@ -92,12 +98,14 @@ void init_random(py::module_& parent_module) {
m.def( m.def(
"normal", "normal",
[](const std::vector<int>& shape, [](const std::vector<int>& shape,
Dtype type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { return normal(shape, type, key, s); }, StreamOrDevice s) {
return normal(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"key"_a = none, "key"_a = none,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
@ -116,10 +124,11 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
Dtype type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { StreamOrDevice s) {
return randint(to_array(low), to_array(high), shape, type, key, s); return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
}, },
"low"_a, "low"_a,
"high"_a, "high"_a,
@ -183,21 +192,22 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& lower_, [](const ScalarOrArray& lower_,
const ScalarOrArray& upper_, const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_, const std::optional<std::vector<int>> shape_,
Dtype dtype, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { StreamOrDevice s) {
auto lower = to_array(lower_); auto lower = to_array(lower_);
auto upper = to_array(upper_); auto upper = to_array(upper_);
auto t = type.value_or(float32);
if (shape_.has_value()) { if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), dtype, key, s); return truncated_normal(lower, upper, shape_.value(), t, key, s);
} else { } else {
return truncated_normal(lower, upper, dtype, key, s); return truncated_normal(lower, upper, t, key, s);
} }
}, },
"lower"_a, "lower"_a,
"upper"_a, "upper"_a,
"shape"_a = none, "shape"_a = none,
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"key"_a = none, "key"_a = none,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
@ -221,9 +231,14 @@ void init_random(py::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"gumbel", "gumbel",
&gumbel, [](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"stream"_a = none, "stream"_a = none,
"key"_a = none, "key"_a = none,
R"pbdoc( R"pbdoc(

View File

@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
for shape in [[4], [4, 4], [2, 10]]: for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]: for diag in [-1, 0, 1, -2]:
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag) self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
def test_tril(self): def test_tril(self):
for diag in [-1, 0, 1, -2]: for diag in [-1, 0, 1, -2]:

View File

@ -61,6 +61,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16) a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
self.assertEqual(a.dtype, mx.bfloat16) self.assertEqual(a.dtype, mx.bfloat16)
self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype)
def test_normal(self): def test_normal(self):
key = mx.random.key(0) key = mx.random.key(0)
a = mx.random.normal(key=key) a = mx.random.normal(key=key)
@ -78,6 +80,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.normal(dtype=t) a = mx.random.normal(dtype=t)
self.assertEqual(a.dtype, t) self.assertEqual(a.dtype, t)
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
def test_randint(self): def test_randint(self):
a = mx.random.randint(0, 1, []) a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, []) self.assertEqual(a.shape, [])
@ -109,6 +113,10 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.randint(10, -10, [1000, 1000]) a = mx.random.randint(10, -10, [1000, 1000])
self.assertTrue(mx.all(a == 10).item()) self.assertTrue(mx.all(a == 10).item())
self.assertEqual(
mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype
)
def test_bernoulli(self): def test_bernoulli(self):
a = mx.random.bernoulli() a = mx.random.bernoulli()
self.assertEqual(a.shape, []) self.assertEqual(a.shape, [])
@ -159,6 +167,11 @@ class TestRandom(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.random.truncated_normal(lower, higher) # Bad shape mx.random.truncated_normal(lower, higher) # Bad shape
self.assertEqual(
mx.random.truncated_normal(0, 1).dtype,
mx.random.truncated_normal(0, 1, dtype=None).dtype,
)
def test_gumbel(self): def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100)) samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100]) self.assertEqual(samples.shape, [100, 100])
@ -168,6 +181,10 @@ class TestRandom(mlx_tests.MLXTestCase):
# so this test is pretty conservative # so this test is pretty conservative
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2) self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
self.assertEqual(
mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype
)
def test_categorical(self): def test_categorical(self):
logits = mx.zeros((10, 20)) logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10]) self.assertEqual(mx.random.categorical(logits, -1).shape, [10])