mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Compile with capture (#629)
* Simple kernel generation * Remove the generate kernel from graph_utils * fix multi-output with compile * fuse with stopgrad * v1 input, output capture in compile * cleanup tree update with visitor update * nit * remove todo * state for model, optional explicit init and more pure optimizer steps * move learning rate to state * add lr to opt state, some fixes in capture * fix optim * update tuple of containers as well * fix stream for compiled output * rng state for compile * nit * updates and comments --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <chrono>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
@@ -13,13 +14,55 @@ using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
class PyKeySequence {
|
||||
public:
|
||||
explicit PyKeySequence(uint64_t seed) {
|
||||
state_.append(key(seed));
|
||||
}
|
||||
|
||||
void seed(uint64_t seed) {
|
||||
state_[0] = key(seed);
|
||||
}
|
||||
|
||||
array next() {
|
||||
auto out = split(py::cast<array>(state_[0]));
|
||||
state_[0] = out.first;
|
||||
return out.second;
|
||||
}
|
||||
|
||||
py::list state() {
|
||||
return state_;
|
||||
}
|
||||
|
||||
void release() {
|
||||
py::gil_scoped_acquire gil;
|
||||
state_.release().dec_ref();
|
||||
}
|
||||
|
||||
private:
|
||||
py::list state_;
|
||||
};
|
||||
|
||||
PyKeySequence& default_key() {
|
||||
auto get_current_time_seed = []() {
|
||||
auto now = std::chrono::system_clock::now();
|
||||
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
now.time_since_epoch())
|
||||
.count();
|
||||
};
|
||||
static PyKeySequence ks(get_current_time_seed());
|
||||
return ks;
|
||||
}
|
||||
|
||||
void init_random(py::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"random",
|
||||
"mlx.core.random: functionality related to random number generation");
|
||||
|
||||
m.attr("state") = default_key().state();
|
||||
m.def(
|
||||
"seed",
|
||||
&seed,
|
||||
[](uint64_t seed) { default_key().seed(seed); },
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Seed the global PRNG.
|
||||
@@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return uniform(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
@@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
|
||||
std::optional<Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
||||
},
|
||||
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"loc"_a = 0.0,
|
||||
@@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return randint(
|
||||
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
|
||||
},
|
||||
@@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto p = to_array(p_);
|
||||
if (shape.has_value()) {
|
||||
return bernoulli(p, shape.value(), key, s);
|
||||
@@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
auto t = type.value_or(float32);
|
||||
@@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
@@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (shape.has_value() && num_samples.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[categorical] At most one of shape or num_samples can be specified.");
|
||||
@@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
|
||||
Returns:
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
)pbdoc");
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user