Make behaviour of dtype arguments consistent and compliant with NumPy (#379)

All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identically when `dtype=None` is passed or when no dtype is passed at all

This is important when passing keyword arguments down from a NumPy-style function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`, where the `dtype` argument was ignored and `float32` was always used: #378

Closes #378
This commit is contained in:
Daniel Strobusch 2024-01-05 18:37:46 +01:00 committed by GitHub
parent d8f41a5c0f
commit dfdb284e16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 29 deletions

View File

@ -1212,14 +1212,22 @@ void init_ops(py::module_& m) {
)pbdoc");
m.def(
"linspace",
[](Scalar start, Scalar stop, int num, Dtype dtype, StreamOrDevice s) {
[](Scalar start,
Scalar stop,
int num,
std::optional<Dtype> dtype,
StreamOrDevice s) {
return linspace(
scalar_to_double(start), scalar_to_double(stop), num, dtype, s);
scalar_to_double(start),
scalar_to_double(stop),
num,
dtype.value_or(float32),
s);
},
"start"_a,
"stop"_a,
"num"_a = 50,
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"stream"_a = none,
R"pbdoc(
linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array
@ -1356,11 +1364,11 @@ void init_ops(py::module_& m) {
}
},
"shape"_a,
"dtype"_a = std::nullopt,
"dtype"_a = std::optional{float32},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of zeros.
@ -1403,11 +1411,11 @@ void init_ops(py::module_& m) {
}
},
"shape"_a,
"dtype"_a = std::nullopt,
"dtype"_a = std::optional{float32},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Construct an array of ones.
@ -1449,11 +1457,11 @@ void init_ops(py::module_& m) {
"n"_a,
"m"_a = py::none(),
"k"_a = 0,
"dtype"_a = std::nullopt,
"dtype"_a = std::optional{float32},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Create an identity matrix or a general diagonal matrix.
@ -1473,11 +1481,11 @@ void init_ops(py::module_& m) {
return identity(n, dtype.value_or(float32), s);
},
"n"_a,
"dtype"_a = std::nullopt,
"dtype"_a = std::optional{float32},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
Create a square identity matrix.
@ -1491,13 +1499,17 @@ void init_ops(py::module_& m) {
)pbdoc");
m.def(
"tri",
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
return tri(n, m.value_or(n), k, float32, s);
[](int n,
std::optional<int> m,
int k,
std::optional<Dtype> type,
StreamOrDevice s) {
return tri(n, m.value_or(n), k, type.value_or(float32), s);
},
"n"_a,
"m"_a = none,
"k"_a = 0,
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
@ -2432,7 +2444,7 @@ void init_ops(py::module_& m) {
array (array): Input array.
repeats (int): The number of repetitions for each element.
axis (int, optional): The axis in which to repeat the array along. If
unspecified it uses the flattened array of the input and repeats
unspecified it uses the flattened array of the input and repeats
along axis 0.
stream (Stream, optional): Stream or device. Defaults to ``None``.

View File

@ -61,15 +61,21 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
Dtype type,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return uniform(to_array(low), to_array(high), shape, type, key, s);
return uniform(
to_array(low),
to_array(high),
shape,
type.value_or(float32),
key,
s);
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
R"pbdoc(
@ -92,12 +98,14 @@ void init_random(py::module_& parent_module) {
m.def(
"normal",
[](const std::vector<int>& shape,
Dtype type,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) { return normal(shape, type, key, s); },
StreamOrDevice s) {
return normal(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
R"pbdoc(
@ -116,10 +124,11 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
Dtype type,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return randint(to_array(low), to_array(high), shape, type, key, s);
return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
},
"low"_a,
"high"_a,
@ -183,21 +192,22 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
Dtype dtype,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
auto lower = to_array(lower_);
auto upper = to_array(upper_);
auto t = type.value_or(float32);
if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), dtype, key, s);
return truncated_normal(lower, upper, shape_.value(), t, key, s);
} else {
return truncated_normal(lower, upper, dtype, key, s);
return truncated_normal(lower, upper, t, key, s);
}
},
"lower"_a,
"upper"_a,
"shape"_a = none,
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
R"pbdoc(
@ -221,9 +231,14 @@ void init_random(py::module_& parent_module) {
)pbdoc");
m.def(
"gumbel",
&gumbel,
[](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"stream"_a = none,
"key"_a = none,
R"pbdoc(

View File

@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]:
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
def test_tril(self):
for diag in [-1, 0, 1, -2]:

View File

@ -61,6 +61,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
self.assertEqual(a.dtype, mx.bfloat16)
self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype)
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)
@ -78,6 +80,8 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.normal(dtype=t)
self.assertEqual(a.dtype, t)
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, [])
@ -109,6 +113,10 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.randint(10, -10, [1000, 1000])
self.assertTrue(mx.all(a == 10).item())
self.assertEqual(
mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype
)
def test_bernoulli(self):
a = mx.random.bernoulli()
self.assertEqual(a.shape, [])
@ -159,6 +167,11 @@ class TestRandom(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.random.truncated_normal(lower, higher) # Bad shape
self.assertEqual(
mx.random.truncated_normal(0, 1).dtype,
mx.random.truncated_normal(0, 1, dtype=None).dtype,
)
def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100])
@ -168,6 +181,10 @@ class TestRandom(mlx_tests.MLXTestCase):
# so this test is pretty conservative
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
self.assertEqual(
mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype
)
def test_categorical(self):
logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])