mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should
* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed
This important when passing kw args down from a numpy function like:
```
def f(x, dtype=None):
mx.random.uniform(dtype=dtype)
# ...
```
NumPy functions behave like this.
It also fixes a minor bug in `tri`: #378
Closes #378
This commit is contained in:
@@ -61,15 +61,21 @@ void init_random(py::module_& parent_module) {
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
return uniform(to_array(low), to_array(high), shape, type, key, s);
|
||||
return uniform(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
shape,
|
||||
type.value_or(float32),
|
||||
key,
|
||||
s);
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"dtype"_a = std::optional{float32},
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
@@ -92,12 +98,14 @@ void init_random(py::module_& parent_module) {
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) { return normal(shape, type, key, s); },
|
||||
StreamOrDevice s) {
|
||||
return normal(shape, type.value_or(float32), key, s);
|
||||
},
|
||||
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"dtype"_a = std::optional{float32},
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
@@ -116,10 +124,11 @@ void init_random(py::module_& parent_module) {
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
Dtype type,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
return randint(to_array(low), to_array(high), shape, type, key, s);
|
||||
return randint(
|
||||
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
@@ -183,21 +192,22 @@ void init_random(py::module_& parent_module) {
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
Dtype dtype,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
auto t = type.value_or(float32);
|
||||
if (shape_.has_value()) {
|
||||
return truncated_normal(lower, upper, shape_.value(), dtype, key, s);
|
||||
return truncated_normal(lower, upper, shape_.value(), t, key, s);
|
||||
} else {
|
||||
return truncated_normal(lower, upper, dtype, key, s);
|
||||
return truncated_normal(lower, upper, t, key, s);
|
||||
}
|
||||
},
|
||||
"lower"_a,
|
||||
"upper"_a,
|
||||
"shape"_a = none,
|
||||
"dtype"_a = float32,
|
||||
"dtype"_a = std::optional{float32},
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
@@ -221,9 +231,14 @@ void init_random(py::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gumbel",
|
||||
&gumbel,
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = float32,
|
||||
"dtype"_a = std::optional{float32},
|
||||
"stream"_a = none,
|
||||
"key"_a = none,
|
||||
R"pbdoc(
|
||||
|
||||
Reference in New Issue
Block a user