Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
WIP: comment out unused eval_cpu parameter names and switch size comparisons to std::ssize
@@ -46,8 +46,9 @@ class RMSNorm : public Custom {
 
   static bool use_fallback(Stream stream);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
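
Note: every eval_cpu change in this header follows one idiom: the parameter
types stay, so the override still matches the base-class declaration, but the
names are commented out to mark the arguments intentionally unused and silence
-Wunused-parameter on the stubs. A minimal sketch of the idiom, using
hypothetical Base/Stub types rather than the MLX classes:

#include <stdexcept>
#include <vector>

struct Base {
  virtual ~Base() = default;
  virtual void eval_cpu(
      const std::vector<int>& inputs,
      std::vector<int>& outputs) = 0;
};

struct Stub : Base {
  // Types stay, names go: the signature still overrides the pure virtual
  // above, but there are no named-yet-unused parameters to warn about.
  void eval_cpu(
      const std::vector<int>& /* inputs */,
      std::vector<int>& /* outputs */) override {
    throw std::runtime_error("NYI");
  }
};

The same idiom shows up once more at the end of the diff, for the unused
outputs parameter of a lambda in custom_function.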
@@ -79,8 +80,9 @@ class RMSNormVJP : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -106,8 +108,9 @@ class LayerNorm : public Custom {
 
   static bool use_fallback(Stream s);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -138,8 +141,9 @@ class LayerNormVJP : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -174,8 +178,9 @@ class RoPE : public Custom {
 
   static bool use_fallback(Stream s);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -225,8 +230,9 @@ class ScaledDotProductAttention : public Custom {
       bool do_causal,
       Stream s);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
 
@@ -320,8 +326,9 @@ class CustomKernel : public Primitive {
         is_precompiled_(is_precompiled),
         shared_memory_(shared_memory) {}
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("Custom kernels only run on GPU.");
   }
 
@@ -89,7 +89,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
     auto& [a_ref, idx] = dfs.top();
    auto& a = a_ref.get();
 
-    if (idx < a.inputs().size()) {
+    if (idx < std::ssize(a.inputs())) {
      // Add an input, and continue
      auto& in = a.inputs()[idx++];
 
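
Note: this hunk and every remaining one make the same substitution. The loop
indices are plain int, and comparing them against the unsigned size_t that
size() returns triggers -Wsign-compare. std::ssize (C++20, <iterator>) yields
the element count as a signed ptrdiff_t, keeping the comparison entirely in
signed arithmetic. A small self-contained illustration (toy data, not MLX
code):

#include <cstdio>
#include <iterator>  // std::ssize
#include <vector>

int main() {
  std::vector<int> v{1, 2, 3};
  // v.size() is size_t (unsigned); std::ssize(v) is the same count as a
  // signed ptrdiff_t, so `i < std::ssize(v)` compares signed to signed.
  for (int i = 0; i < std::ssize(v); ++i) {
    std::printf("%d\n", v[i]);
  }
}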
@@ -146,16 +146,16 @@ array eval_impl(std::vector<array> outputs, bool async) {
   int max_width = env::bfs_max_width();
   dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
   tape.push_back(synchronizer);
-  for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
-    auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
+  for (int i = 0; !cache.empty() && (i < std::ssize(tape) || !dfs.empty());) {
+    auto& a = (i >= std::ssize(tape)) ? dfs.top().first.get() : tape[i];
     int j = 0;
-    if (i >= tape.size()) {
+    if (i >= std::ssize(tape)) {
       j = dfs.top().second;
       dfs.pop();
     } else {
       i++;
     }
-    for (; j < a.inputs().size(); ++j) {
+    for (; j < std::ssize(a.inputs()); ++j) {
       auto& in = a.inputs()[j];
       if (in.status() != array::Status::unscheduled) {
         continue;
@@ -163,7 +163,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
 
       // If the width limit is exceeded, push the array on the stack
       // and go down a level
-      if ((tape.size() - i) >= max_width) {
+      if ((std::ssize(tape) - i) >= max_width) {
         dfs.emplace(a, j);
         break;
       }
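
Note: this hunk is arguably more than warning hygiene. tape.size() - i is
computed in unsigned arithmetic, so if i ever exceeded the tape length the
subtraction would wrap to a huge value and the width check would always fire;
after the change the subtraction is signed and a negative result simply
compares below max_width. A toy demonstration of the difference (illustrative
values only, not MLX state):

#include <cstdio>
#include <iterator>
#include <vector>

int main() {
  std::vector<int> tape(4);
  int i = 6;  // an index past the end, for illustration
  // size_t arithmetic: 4 - 6 wraps around to a huge unsigned value.
  std::printf("unsigned: %zu\n", tape.size() - i);
  // ptrdiff_t arithmetic: 4 - 6 == -2, which compares sanely to a limit.
  std::printf("signed:   %td\n", std::ssize(tape) - i);
}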
@@ -343,14 +343,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // that have stop_gradient called on them
   int cotan_index = 0;
   std::vector<std::pair<int, int>> output_cotan_pairs;
-  for (int i = 0; i < outputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(outputs); ++i) {
     auto& out = outputs[i];
     if (out.has_primitive()) {
       if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
         continue;
       }
     }
-    if (cotan_index >= cotans.size()) {
+    if (cotan_index >= std::ssize(cotans)) {
       std::ostringstream msg;
       msg << "[vjp] Number of outputs to compute gradients for ("
           << outputs.size() << ") does not match number of cotangents ("
@@ -374,11 +374,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // to the tape which need a gradient.
   std::unordered_set<std::uintptr_t> cache;
   std::unordered_set<std::uintptr_t> calc_grad;
-  for (int i = 0, j = 0; i < primals_.size(); ++i) {
+  for (int i = 0, j = 0; i < std::ssize(primals_); ++i) {
     auto& primal = primals_[i];
     primal.set_tracer(false);
     cache.insert(primal.id());
-    if (j < argnums.size() && argnums[j] == i) {
+    if (j < std::ssize(argnums) && argnums[j] == i) {
       j++;
       calc_grad.insert(primal.id());
     }
@@ -440,7 +440,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
 
     // Get the arguments whose gradients are needed
     std::vector<int> argnums;
-    for (int i = 0; i < a.inputs().size(); ++i) {
+    for (int i = 0; i < std::ssize(a.inputs()); ++i) {
       if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
         argnums.push_back(i);
       }
@@ -473,7 +473,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
       vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
     }
     // Accumulate the vector-jacobian products for each input
-    for (int i = 0; i < argnums.size(); ++i) {
+    for (int i = 0; i < std::ssize(argnums); ++i) {
       auto in_id = a.inputs()[argnums[i]].id();
       if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
         cotan_it->second = add(cotan_it->second, vjps[i], s);
@@ -528,7 +528,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
     throw std::invalid_argument(
         "[jvp] Number of inputs does not match number of tangents.");
   }
-  for (int i = 0; i < primals.size(); ++i) {
+  for (int i = 0; i < std::ssize(primals); ++i) {
     if (primals[i].shape() != tangents[i].shape()) {
       throw std::invalid_argument(
           "[jvp] Input shape does not match shape of tangent.");
@@ -597,7 +597,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
   }
 
   std::unordered_map<std::uintptr_t, array> tan_map;
-  for (int i = 0; i < primals_.size(); ++i) {
+  for (int i = 0; i < std::ssize(primals_); ++i) {
     tan_map.insert({primals_[i].id(), tangents[i]});
   }
 
@@ -605,7 +605,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
     // Get the arguments used in the jvp
     std::vector<int> argnums;
     std::vector<array> tangents;
-    for (int i = 0; i < a.inputs().size(); ++i) {
+    for (int i = 0; i < std::ssize(a.inputs()); ++i) {
       if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
         argnums.push_back(i);
         tangents.push_back(it->second);
@@ -614,7 +614,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
 
     auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);
     auto outputs = a.outputs();
-    for (int i = 0; i < jvps.size(); ++i) {
+    for (int i = 0; i < std::ssize(jvps); ++i) {
       tan_map.insert({outputs[i].id(), jvps[i]});
     }
   }
@@ -658,7 +658,7 @@ ValueAndGradFn value_and_grad(
     throw std::invalid_argument(
         "[grad] Repeat argument number not allowed in grad.");
   }
-  if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {
+  if (*args.begin() < 0 || *args.rbegin() >= std::ssize(inputs)) {
     std::ostringstream msg;
     msg << "[grad] Invalid argument number for function with "
         << inputs.size() << " inputs.";
@@ -668,7 +668,7 @@ ValueAndGradFn value_and_grad(
 
   auto gfun = [&fun](const std::vector<array>& inputs) {
     auto outputs = fun(inputs);
-    for (int i = 1; i < outputs.size(); i++) {
+    for (int i = 1; i < std::ssize(outputs); i++) {
      auto& out = outputs[i];
      auto s = out.has_primitive() ? out.primitive().stream()
                                   : default_stream(default_device());
@@ -701,7 +701,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
 
   // Some error checking and get the vmap axis size
   size_t vmap_ax_size;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] != -1) {
       if (inputs[i].ndim() == 0) {
         throw std::invalid_argument(
@@ -717,7 +717,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
     }
   }
   // Check that all vmapped axes have the same size
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] != -1) {
       if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
         std::ostringstream msg;
@@ -731,7 +731,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
   // Run the function on placeholder inputs
   // to get the original graph
   std::vector<array> s_inputs;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] != -1) {
       auto shape = inputs[i].shape();
       shape.erase(shape.begin() + in_axes[i]);
@@ -759,7 +759,7 @@ std::vector<array> vmap_replace(
   }
 
   int vmap_size = -1;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] >= 0) {
       vmap_size = inputs[i].shape(in_axes[i]);
       break;
@@ -772,7 +772,7 @@ std::vector<array> vmap_replace(
   std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
   std::unordered_set<std::uintptr_t> needs_vmap;
   std::unordered_set<std::uintptr_t> cache;
-  for (int i = 0; i < s_inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(s_inputs); ++i) {
     auto in = s_inputs[i];
     if (in_axes[i] != -1) {
       tmap.insert({in.id(), {inputs[i], in_axes[i]}});
@@ -843,7 +843,7 @@ std::vector<array> vmap_replace(
 
     // For each primitive's outputs add its id, the vout id and the vax
     auto outputs = a.outputs();
-    for (int i = 0; i < v_outputs.size(); ++i) {
+    for (int i = 0; i < std::ssize(v_outputs); ++i) {
       tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});
     }
   }
@@ -851,7 +851,7 @@ std::vector<array> vmap_replace(
   // Populate the outputs and make sure all the output axes are
   // in the right place
   std::vector<array> outputs;
-  for (int i = 0; i < s_outputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(s_outputs); ++i) {
     if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {
       auto& [out, vdim] = map_it->second;
       if (vdim != out_axes[i]) {
@@ -995,7 +995,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
       // using `fun` directly because we may not be able to fully reuse
       // the outputs of the forward pass.
       fun_vjp.value_or(
-          [fun](auto primals, auto cotangents, auto outputs) {
+          [fun](auto primals, auto cotangents, auto /* outputs */) {
             auto [__, vjps] = vjp(fun, primals, cotangents);
             return vjps;
           }),
@@ -1009,8 +1009,8 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
       // waste computation.
       fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
         std::vector<array> all_tangents;
-        for (int i = 0, j = 0; i < primals.size(); i++) {
-          if (j < argnums.size() && i == argnums[j]) {
+        for (int i = 0, j = 0; i < std::ssize(primals); i++) {
+          if (j < std::ssize(argnums) && i == argnums[j]) {
            all_tangents.emplace_back(tangents[j++]);
          } else {
            all_tangents.emplace_back(zeros_like(primals[i]));