mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"uniform",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
@@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::normal(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
@@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"multivariate_normal",
|
||||
[](const mx::array& mean,
|
||||
const mx::array& cov,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"mean"_a,
|
||||
"cov"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"randint",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<mx::Shape> shape,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
@@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"truncated_normal",
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
const std::optional<mx::Shape> shape_,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"categorical",
|
||||
[](const mx::array& logits,
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<mx::Shape> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"laplace",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
@@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::laplace(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
@@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
|
||||
}
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"x"_a,
|
||||
"axis"_a = 0,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
|
||||
Reference in New Issue
Block a user