make behaviour of dtype arguments consistent and compliant to numpy (#379)

All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed

This important when passing kw args down from a numpy function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`: #378

Closes #378
This commit is contained in:
Daniel Strobusch
2024-01-05 18:37:46 +01:00
committed by GitHub
parent d8f41a5c0f
commit dfdb284e16
4 changed files with 75 additions and 29 deletions

View File

@@ -61,15 +61,21 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
Dtype type,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return uniform(to_array(low), to_array(high), shape, type, key, s);
return uniform(
to_array(low),
to_array(high),
shape,
type.value_or(float32),
key,
s);
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
R"pbdoc(
@@ -92,12 +98,14 @@ void init_random(py::module_& parent_module) {
m.def(
"normal",
[](const std::vector<int>& shape,
Dtype type,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) { return normal(shape, type, key, s); },
StreamOrDevice s) {
return normal(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
R"pbdoc(
@@ -116,10 +124,11 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
Dtype type,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return randint(to_array(low), to_array(high), shape, type, key, s);
return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
},
"low"_a,
"high"_a,
@@ -183,21 +192,22 @@ void init_random(py::module_& parent_module) {
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
Dtype dtype,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
auto lower = to_array(lower_);
auto upper = to_array(upper_);
auto t = type.value_or(float32);
if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), dtype, key, s);
return truncated_normal(lower, upper, shape_.value(), t, key, s);
} else {
return truncated_normal(lower, upper, dtype, key, s);
return truncated_normal(lower, upper, t, key, s);
}
},
"lower"_a,
"upper"_a,
"shape"_a = none,
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
R"pbdoc(
@@ -221,9 +231,14 @@ void init_random(py::module_& parent_module) {
)pbdoc");
m.def(
"gumbel",
&gumbel,
[](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
StreamOrDevice s) {
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = float32,
"dtype"_a = std::optional{float32},
"stream"_a = none,
"key"_a = none,
R"pbdoc(