More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
"uniform",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"normal",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
float loc,
float scale,
@@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
@@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
"multivariate_normal",
[](const mx::array& mean,
const mx::array& cov,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
},
"mean"_a,
"cov"_a,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
"randint",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
},
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
m.def(
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<mx::Shape> shape,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
@@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
"truncated_normal",
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
const std::optional<mx::Shape> shape_,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"gumbel",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
"categorical",
[](const mx::array& logits,
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<mx::Shape> shape,
const std::optional<int> num_samples,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"laplace",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
float loc,
float scale,
@@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::laplace(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
@@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
}
},
"shape"_a = std::vector<int>{},
"x"_a,
"axis"_a = 0,
"key"_a = nb::none(),
"stream"_a = nb::none(),