mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove "using namespace mlx::core" in python/src (#1689)
This commit is contained in:
@@ -12,23 +12,22 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
class PyKeySequence {
|
||||
public:
|
||||
explicit PyKeySequence(uint64_t seed) {
|
||||
state_.append(key(seed));
|
||||
state_.append(mx::random::key(seed));
|
||||
}
|
||||
|
||||
void seed(uint64_t seed) {
|
||||
state_[0] = key(seed);
|
||||
state_[0] = mx::random::key(seed);
|
||||
}
|
||||
|
||||
array next() {
|
||||
auto out = split(nb::cast<array>(state_[0]));
|
||||
mx::array next() {
|
||||
auto out = mx::random::split(nb::cast<mx::array>(state_[0]));
|
||||
state_[0] = out.first;
|
||||
return out.second;
|
||||
}
|
||||
@@ -75,7 +74,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"key",
|
||||
&key,
|
||||
&mx::random::key,
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Get a PRNG key from a seed.
|
||||
@@ -88,7 +87,8 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"split",
|
||||
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
nb::overload_cast<const mx::array&, int, mx::StreamOrDevice>(
|
||||
&mx::random::split),
|
||||
"key"_a,
|
||||
"num"_a = 2,
|
||||
"stream"_a = nb::none(),
|
||||
@@ -109,22 +109,22 @@ void init_random(nb::module_& parent_module) {
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return uniform(
|
||||
return mx::random::uniform(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
shape,
|
||||
type.value_or(float32),
|
||||
type.value_or(mx::float32),
|
||||
key,
|
||||
s);
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@@ -151,16 +151,17 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
||||
return mx::random::normal(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
"key"_a = nb::none(),
|
||||
@@ -182,20 +183,20 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"multivariate_normal",
|
||||
[](const array& mean,
|
||||
const array& cov,
|
||||
[](const mx::array& mean,
|
||||
const mx::array& cov,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return multivariate_normal(
|
||||
mean, cov, shape, type.value_or(float32), key, s);
|
||||
return mx::random::multivariate_normal(
|
||||
mean, cov, shape, type.value_or(mx::float32), key, s);
|
||||
},
|
||||
"mean"_a,
|
||||
"cov"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@@ -227,17 +228,22 @@ void init_random(nb::module_& parent_module) {
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return randint(
|
||||
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
|
||||
return mx::random::randint(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
shape,
|
||||
type.value_or(mx::int32),
|
||||
key,
|
||||
s);
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = int32,
|
||||
"dtype"_a.none() = mx::int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@@ -263,14 +269,14 @@ void init_random(nb::module_& parent_module) {
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto p = to_array(p_);
|
||||
if (shape.has_value()) {
|
||||
return bernoulli(p, shape.value(), key, s);
|
||||
return mx::random::bernoulli(p, shape.value(), key, s);
|
||||
} else {
|
||||
return bernoulli(p, key, s);
|
||||
return mx::random::bernoulli(p, key, s);
|
||||
}
|
||||
},
|
||||
"p"_a = 0.5,
|
||||
@@ -301,23 +307,24 @@ void init_random(nb::module_& parent_module) {
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
auto t = type.value_or(float32);
|
||||
auto t = type.value_or(mx::float32);
|
||||
if (shape_.has_value()) {
|
||||
return truncated_normal(lower, upper, shape_.value(), t, key, s);
|
||||
return mx::random::truncated_normal(
|
||||
lower, upper, shape_.value(), t, key, s);
|
||||
} else {
|
||||
return truncated_normal(lower, upper, t, key, s);
|
||||
return mx::random::truncated_normal(lower, upper, t, key, s);
|
||||
}
|
||||
},
|
||||
"lower"_a,
|
||||
"upper"_a,
|
||||
"shape"_a = nb::none(),
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@@ -344,14 +351,14 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@@ -375,22 +382,23 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"categorical",
|
||||
[](const array& logits,
|
||||
[](const mx::array& logits,
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (shape.has_value() && num_samples.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[categorical] At most one of shape or num_samples can be specified.");
|
||||
} else if (shape.has_value()) {
|
||||
return categorical(logits, axis, shape.value(), key, s);
|
||||
return mx::random::categorical(logits, axis, shape.value(), key, s);
|
||||
} else if (num_samples.has_value()) {
|
||||
return categorical(logits, axis, num_samples.value(), key, s);
|
||||
return mx::random::categorical(
|
||||
logits, axis, num_samples.value(), key, s);
|
||||
} else {
|
||||
return categorical(logits, axis, key, s);
|
||||
return mx::random::categorical(logits, axis, key, s);
|
||||
}
|
||||
},
|
||||
"logits"_a,
|
||||
@@ -427,16 +435,17 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"laplace",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return laplace(shape, type.value_or(float32), loc, scale, key, s);
|
||||
return mx::random::laplace(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
"key"_a = nb::none(),
|
||||
@@ -459,15 +468,15 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"permuation",
|
||||
[](const std::variant<nb::int_, array>& x,
|
||||
[](const std::variant<nb::int_, mx::array>& x,
|
||||
int axis,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (auto pv = std::get_if<nb::int_>(&x); pv) {
|
||||
return permutation(nb::cast<int>(*pv), key, s);
|
||||
return mx::random::permutation(nb::cast<int>(*pv), key, s);
|
||||
} else {
|
||||
return permutation(std::get<array>(x), axis, key, s);
|
||||
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
|
||||
}
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
|
||||
Reference in New Issue
Block a user