Make behaviour of dtype arguments consistent and compliant with NumPy (#379)

All functions that take an optional dtype should:

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identically when `dtype=None` is passed or no dtype is passed

This is important when passing keyword arguments down from a NumPy-style function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`, whose binding ignored the `dtype` argument and always used `float32`: #378

Closes #378
This commit is contained in:
Daniel Strobusch 2024-01-05 18:37:46 +01:00 committed by GitHub
parent d8f41a5c0f
commit dfdb284e16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 29 deletions

View File

@ -1212,14 +1212,22 @@ void init_ops(py::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"linspace", "linspace",
[](Scalar start, Scalar stop, int num, Dtype dtype, StreamOrDevice s) { [](Scalar start,
Scalar stop,
int num,
std::optional<Dtype> dtype,
StreamOrDevice s) {
return linspace( return linspace(
scalar_to_double(start), scalar_to_double(stop), num, dtype, s); scalar_to_double(start),
scalar_to_double(stop),
num,
dtype.value_or(float32),
s);
}, },
"start"_a, "start"_a,
"stop"_a, "stop"_a,
"num"_a = 50, "num"_a = 50,
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array
@ -1356,11 +1364,11 @@ void init_ops(py::module_& m) {
} }
}, },
"shape"_a, "shape"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of zeros. Construct an array of zeros.
@ -1403,11 +1411,11 @@ void init_ops(py::module_& m) {
} }
}, },
"shape"_a, "shape"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of ones. Construct an array of ones.
@ -1449,11 +1457,11 @@ void init_ops(py::module_& m) {
"n"_a, "n"_a,
"m"_a = py::none(), "m"_a = py::none(),
"k"_a = 0, "k"_a = 0,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Create an identity matrix or a general diagonal matrix. Create an identity matrix or a general diagonal matrix.
@ -1473,11 +1481,11 @@ void init_ops(py::module_& m) {
return identity(n, dtype.value_or(float32), s); return identity(n, dtype.value_or(float32), s);
}, },
"n"_a, "n"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Create a square identity matrix. Create a square identity matrix.
@ -1491,13 +1499,17 @@ void init_ops(py::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"tri", "tri",
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) { [](int n,
return tri(n, m.value_or(n), k, float32, s); std::optional<int> m,
int k,
std::optional<Dtype> type,
StreamOrDevice s) {
return tri(n, m.value_or(n), k, type.value_or(float32), s);
}, },
"n"_a, "n"_a,
"m"_a = none, "m"_a = none,
"k"_a = 0, "k"_a = 0,
"dtype"_a = float32, "dtype"_a = std::optional{float32},
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
@ -2432,7 +2444,7 @@ void init_ops(py::module_& m) {
array (array): Input array. array (array): Input array.
repeats (int): The number of repetitions for each element. repeats (int): The number of repetitions for each element.
axis (int, optional): The axis in which to repeat the array along. If axis (int, optional): The axis in which to repeat the array along. If
unspecified it uses the flattened array of the input and repeats unspecified it uses the flattened array of the input and repeats
along axis 0. along axis 0.
stream (Stream, optional): Stream or device. Defaults to ``None``. stream (Stream, optional): Stream or device. Defaults to ``None``.

View File

@ -61,15 +61,21 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
Dtype type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { StreamOrDevice s) {
return uniform(to_array(low), to_array(high), shape, type, key, s); return uniform(
to_array(low),
to_array(high),
shape,
type.value_or(float32),
key,
s);
}, },
"low"_a = 0, "low"_a = 0,
"high"_a = 1, "high"_a = 1,
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"key"_a = none, "key"_a = none,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
@ -92,12 +98,14 @@ void init_random(py::module_& parent_module) {
m.def( m.def(
"normal", "normal",
[](const std::vector<int>& shape, [](const std::vector<int>& shape,
Dtype type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { return normal(shape, type, key, s); }, StreamOrDevice s) {
return normal(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"key"_a = none, "key"_a = none,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
@ -116,10 +124,11 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
Dtype type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { StreamOrDevice s) {
return randint(to_array(low), to_array(high), shape, type, key, s); return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
}, },
"low"_a, "low"_a,
"high"_a, "high"_a,
@ -183,21 +192,22 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& lower_, [](const ScalarOrArray& lower_,
const ScalarOrArray& upper_, const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_, const std::optional<std::vector<int>> shape_,
Dtype dtype, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { StreamOrDevice s) {
auto lower = to_array(lower_); auto lower = to_array(lower_);
auto upper = to_array(upper_); auto upper = to_array(upper_);
auto t = type.value_or(float32);
if (shape_.has_value()) { if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), dtype, key, s); return truncated_normal(lower, upper, shape_.value(), t, key, s);
} else { } else {
return truncated_normal(lower, upper, dtype, key, s); return truncated_normal(lower, upper, t, key, s);
} }
}, },
"lower"_a, "lower"_a,
"upper"_a, "upper"_a,
"shape"_a = none, "shape"_a = none,
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"key"_a = none, "key"_a = none,
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
@ -221,9 +231,14 @@ void init_random(py::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"gumbel", "gumbel",
&gumbel, [](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = float32, "dtype"_a = std::optional{float32},
"stream"_a = none, "stream"_a = none,
"key"_a = none, "key"_a = none,
R"pbdoc( R"pbdoc(

View File

@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
for shape in [[4], [4, 4], [2, 10]]: for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]: for diag in [-1, 0, 1, -2]:
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag) self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
def test_tril(self): def test_tril(self):
for diag in [-1, 0, 1, -2]: for diag in [-1, 0, 1, -2]:

View File

@ -61,6 +61,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16) a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
self.assertEqual(a.dtype, mx.bfloat16) self.assertEqual(a.dtype, mx.bfloat16)
self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype)
def test_normal(self): def test_normal(self):
key = mx.random.key(0) key = mx.random.key(0)
a = mx.random.normal(key=key) a = mx.random.normal(key=key)
@ -78,6 +80,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.normal(dtype=t) a = mx.random.normal(dtype=t)
self.assertEqual(a.dtype, t) self.assertEqual(a.dtype, t)
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
def test_randint(self): def test_randint(self):
a = mx.random.randint(0, 1, []) a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, []) self.assertEqual(a.shape, [])
@ -109,6 +113,10 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.randint(10, -10, [1000, 1000]) a = mx.random.randint(10, -10, [1000, 1000])
self.assertTrue(mx.all(a == 10).item()) self.assertTrue(mx.all(a == 10).item())
self.assertEqual(
mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype
)
def test_bernoulli(self): def test_bernoulli(self):
a = mx.random.bernoulli() a = mx.random.bernoulli()
self.assertEqual(a.shape, []) self.assertEqual(a.shape, [])
@ -159,6 +167,11 @@ class TestRandom(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.random.truncated_normal(lower, higher) # Bad shape mx.random.truncated_normal(lower, higher) # Bad shape
self.assertEqual(
mx.random.truncated_normal(0, 1).dtype,
mx.random.truncated_normal(0, 1, dtype=None).dtype,
)
def test_gumbel(self): def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100)) samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100]) self.assertEqual(samples.shape, [100, 100])
@ -168,6 +181,10 @@ class TestRandom(mlx_tests.MLXTestCase):
# so this test is pretty conservative # so this test is pretty conservative
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2) self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
self.assertEqual(
mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype
)
def test_categorical(self): def test_categorical(self):
logits = mx.zeros((10, 20)) logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10]) self.assertEqual(mx.random.categorical(logits, -1).shape, [10])