Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng
2024-12-12 08:45:39 +09:00
committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
22 changed files with 1423 additions and 1302 deletions

View File

@@ -12,23 +12,22 @@
#include "mlx/ops.h"
#include "mlx/random.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::random;
class PyKeySequence {
public:
explicit PyKeySequence(uint64_t seed) {
state_.append(key(seed));
state_.append(mx::random::key(seed));
}
void seed(uint64_t seed) {
state_[0] = key(seed);
state_[0] = mx::random::key(seed);
}
array next() {
auto out = split(nb::cast<array>(state_[0]));
mx::array next() {
auto out = mx::random::split(nb::cast<mx::array>(state_[0]));
state_[0] = out.first;
return out.second;
}
@@ -75,7 +74,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"key",
&key,
&mx::random::key,
"seed"_a,
R"pbdoc(
Get a PRNG key from a seed.
@@ -88,7 +87,8 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"split",
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
nb::overload_cast<const mx::array&, int, mx::StreamOrDevice>(
&mx::random::split),
"key"_a,
"num"_a = 2,
"stream"_a = nb::none(),
@@ -109,22 +109,22 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return uniform(
return mx::random::uniform(
to_array(low),
to_array(high),
shape,
type.value_or(float32),
type.value_or(mx::float32),
key,
s);
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@@ -151,16 +151,17 @@ void init_random(nb::module_& parent_module) {
m.def(
"normal",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
std::optional<mx::Dtype> type,
float loc,
float scale,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return normal(shape, type.value_or(float32), loc, scale, key, s);
return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
"key"_a = nb::none(),
@@ -182,20 +183,20 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"multivariate_normal",
[](const array& mean,
const array& cov,
[](const mx::array& mean,
const mx::array& cov,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return multivariate_normal(
mean, cov, shape, type.value_or(float32), key, s);
return mx::random::multivariate_normal(
mean, cov, shape, type.value_or(mx::float32), key, s);
},
"mean"_a,
"cov"_a,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@@ -227,17 +228,22 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
return mx::random::randint(
to_array(low),
to_array(high),
shape,
type.value_or(mx::int32),
key,
s);
},
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = int32,
"dtype"_a.none() = mx::int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@@ -263,14 +269,14 @@ void init_random(nb::module_& parent_module) {
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto p = to_array(p_);
if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s);
return mx::random::bernoulli(p, shape.value(), key, s);
} else {
return bernoulli(p, key, s);
return mx::random::bernoulli(p, key, s);
}
},
"p"_a = 0.5,
@@ -301,23 +307,24 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto lower = to_array(lower_);
auto upper = to_array(upper_);
auto t = type.value_or(float32);
auto t = type.value_or(mx::float32);
if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), t, key, s);
return mx::random::truncated_normal(
lower, upper, shape_.value(), t, key, s);
} else {
return truncated_normal(lower, upper, t, key, s);
return mx::random::truncated_normal(lower, upper, t, key, s);
}
},
"lower"_a,
"upper"_a,
"shape"_a = nb::none(),
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@@ -344,14 +351,14 @@ void init_random(nb::module_& parent_module) {
m.def(
"gumbel",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return gumbel(shape, type.value_or(float32), key, s);
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@@ -375,22 +382,23 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"categorical",
[](const array& logits,
[](const mx::array& logits,
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified.");
} else if (shape.has_value()) {
return categorical(logits, axis, shape.value(), key, s);
return mx::random::categorical(logits, axis, shape.value(), key, s);
} else if (num_samples.has_value()) {
return categorical(logits, axis, num_samples.value(), key, s);
return mx::random::categorical(
logits, axis, num_samples.value(), key, s);
} else {
return categorical(logits, axis, key, s);
return mx::random::categorical(logits, axis, key, s);
}
},
"logits"_a,
@@ -427,16 +435,17 @@ void init_random(nb::module_& parent_module) {
m.def(
"laplace",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
std::optional<mx::Dtype> type,
float loc,
float scale,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return laplace(shape, type.value_or(float32), loc, scale, key, s);
return mx::random::laplace(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
"key"_a = nb::none(),
@@ -459,15 +468,15 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"permuation",
[](const std::variant<nb::int_, array>& x,
[](const std::variant<nb::int_, mx::array>& x,
int axis,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (auto pv = std::get_if<nb::int_>(&x); pv) {
return permutation(nb::cast<int>(*pv), key, s);
return mx::random::permutation(nb::cast<int>(*pv), key, s);
} else {
return permutation(std::get<array>(x), axis, key, s);
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
}
},
"shape"_a = std::vector<int>{},