Compile with capture (#629)

* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-07 17:29:22 -08:00
committed by GitHub
parent e5e816a5ef
commit 1b97b2958b
13 changed files with 723 additions and 157 deletions

View File

@@ -2,6 +2,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <chrono>
#include "python/src/utils.h"
@@ -13,13 +14,55 @@ using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::random;
class PyKeySequence {
public:
explicit PyKeySequence(uint64_t seed) {
state_.append(key(seed));
}
void seed(uint64_t seed) {
state_[0] = key(seed);
}
array next() {
auto out = split(py::cast<array>(state_[0]));
state_[0] = out.first;
return out.second;
}
py::list state() {
return state_;
}
void release() {
py::gil_scoped_acquire gil;
state_.release().dec_ref();
}
private:
py::list state_;
};
PyKeySequence& default_key() {
auto get_current_time_seed = []() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
};
static PyKeySequence ks(get_current_time_seed());
return ks;
}
void init_random(py::module_& parent_module) {
auto m = parent_module.def_submodule(
"random",
"mlx.core.random: functionality related to random number generation");
m.attr("state") = default_key().state();
m.def(
"seed",
&seed,
[](uint64_t seed) { default_key().seed(seed); },
"seed"_a,
R"pbdoc(
Seed the global PRNG.
@@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return uniform(
to_array(low),
to_array(high),
@@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
std::optional<Dtype> type,
float loc,
float scale,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return normal(shape, type.value_or(float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"loc"_a = 0.0,
@@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
},
@@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto p = to_array(p_);
if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s);
@@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto lower = to_array(lower_);
auto upper = to_array(upper_);
auto t = type.value_or(float32);
@@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
"gumbel",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
@@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified.");
@@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
Returns:
array: The ``shape``-sized output array with type ``uint32``.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
}