This commit is contained in:
Ronan Collobert
2025-10-29 16:51:05 -07:00
parent 3d67b717a0
commit 53525cba23
2 changed files with 49 additions and 42 deletions

View File

@@ -46,8 +46,9 @@ class RMSNorm : public Custom {
static bool use_fallback(Stream stream);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
void eval_cpu(
const std::vector<array>& /* inputs */,
std::vector<array>& /* outputs */) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -79,8 +80,9 @@ class RMSNormVJP : public Custom {
float eps)
: Custom(stream, fallback), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
void eval_cpu(
const std::vector<array>& /* inputs */,
std::vector<array>& /* outputs */) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -106,8 +108,9 @@ class LayerNorm : public Custom {
static bool use_fallback(Stream s);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
void eval_cpu(
const std::vector<array>& /* inputs */,
std::vector<array>& /* outputs */) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -138,8 +141,9 @@ class LayerNormVJP : public Custom {
float eps)
: Custom(stream, fallback), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
void eval_cpu(
const std::vector<array>& /* inputs */,
std::vector<array>& /* outputs */) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -174,8 +178,9 @@ class RoPE : public Custom {
static bool use_fallback(Stream s);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
void eval_cpu(
const std::vector<array>& /* inputs */,
std::vector<array>& /* outputs */) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -225,8 +230,9 @@ class ScaledDotProductAttention : public Custom {
bool do_causal,
Stream s);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
void eval_cpu(
const std::vector<array>& /* inputs */,
std::vector<array>& /* outputs */) override {
throw std::runtime_error("NYI");
}
@@ -320,8 +326,9 @@ class CustomKernel : public Primitive {
is_precompiled_(is_precompiled),
shared_memory_(shared_memory) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
void eval_cpu(
const std::vector<array>& /* inputs */,
std::vector<array>& /* outputs */) override {
throw std::runtime_error("Custom kernels only run on GPU.");
}

View File

@@ -89,7 +89,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
auto& [a_ref, idx] = dfs.top();
auto& a = a_ref.get();
if (idx < a.inputs().size()) {
if (idx < std::ssize(a.inputs())) {
// Add an input, and continue
auto& in = a.inputs()[idx++];
@@ -146,16 +146,16 @@ array eval_impl(std::vector<array> outputs, bool async) {
int max_width = env::bfs_max_width();
dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
tape.push_back(synchronizer);
for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
for (int i = 0; !cache.empty() && (i < std::ssize(tape) || !dfs.empty());) {
auto& a = (i >= std::ssize(tape)) ? dfs.top().first.get() : tape[i];
int j = 0;
if (i >= tape.size()) {
if (i >= std::ssize(tape)) {
j = dfs.top().second;
dfs.pop();
} else {
i++;
}
for (; j < a.inputs().size(); ++j) {
for (; j < std::ssize(a.inputs()); ++j) {
auto& in = a.inputs()[j];
if (in.status() != array::Status::unscheduled) {
continue;
@@ -163,7 +163,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
// If the width limit is exceeded, push the array on the stack
// and go down a level
if ((tape.size() - i) >= max_width) {
if ((std::ssize(tape) - i) >= max_width) {
dfs.emplace(a, j);
break;
}
@@ -343,14 +343,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
// that have stop_gradient called on them
int cotan_index = 0;
std::vector<std::pair<int, int>> output_cotan_pairs;
for (int i = 0; i < outputs.size(); ++i) {
for (int i = 0; i < std::ssize(outputs); ++i) {
auto& out = outputs[i];
if (out.has_primitive()) {
if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
continue;
}
}
if (cotan_index >= cotans.size()) {
if (cotan_index >= std::ssize(cotans)) {
std::ostringstream msg;
msg << "[vjp] Number of outputs to compute gradients for ("
<< outputs.size() << ") does not match number of cotangents ("
@@ -374,11 +374,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
// to the tape which need a gradient.
std::unordered_set<std::uintptr_t> cache;
std::unordered_set<std::uintptr_t> calc_grad;
for (int i = 0, j = 0; i < primals_.size(); ++i) {
for (int i = 0, j = 0; i < std::ssize(primals_); ++i) {
auto& primal = primals_[i];
primal.set_tracer(false);
cache.insert(primal.id());
if (j < argnums.size() && argnums[j] == i) {
if (j < std::ssize(argnums) && argnums[j] == i) {
j++;
calc_grad.insert(primal.id());
}
@@ -440,7 +440,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
// Get the arguments whose gradients are needed
std::vector<int> argnums;
for (int i = 0; i < a.inputs().size(); ++i) {
for (int i = 0; i < std::ssize(a.inputs()); ++i) {
if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
argnums.push_back(i);
}
@@ -473,7 +473,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
}
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
for (int i = 0; i < std::ssize(argnums); ++i) {
auto in_id = a.inputs()[argnums[i]].id();
if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
cotan_it->second = add(cotan_it->second, vjps[i], s);
@@ -528,7 +528,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
throw std::invalid_argument(
"[jvp] Number of inputs does not match number of tangents.");
}
for (int i = 0; i < primals.size(); ++i) {
for (int i = 0; i < std::ssize(primals); ++i) {
if (primals[i].shape() != tangents[i].shape()) {
throw std::invalid_argument(
"[jvp] Input shape does not match shape of tangent.");
@@ -597,7 +597,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
}
std::unordered_map<std::uintptr_t, array> tan_map;
for (int i = 0; i < primals_.size(); ++i) {
for (int i = 0; i < std::ssize(primals_); ++i) {
tan_map.insert({primals_[i].id(), tangents[i]});
}
@@ -605,7 +605,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
// Get the arguments used in the jvp
std::vector<int> argnums;
std::vector<array> tangents;
for (int i = 0; i < a.inputs().size(); ++i) {
for (int i = 0; i < std::ssize(a.inputs()); ++i) {
if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
argnums.push_back(i);
tangents.push_back(it->second);
@@ -614,7 +614,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);
auto outputs = a.outputs();
for (int i = 0; i < jvps.size(); ++i) {
for (int i = 0; i < std::ssize(jvps); ++i) {
tan_map.insert({outputs[i].id(), jvps[i]});
}
}
@@ -658,7 +658,7 @@ ValueAndGradFn value_and_grad(
throw std::invalid_argument(
"[grad] Repeat argument number not allowed in grad.");
}
if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {
if (*args.begin() < 0 || *args.rbegin() >= std::ssize(inputs)) {
std::ostringstream msg;
msg << "[grad] Invalid argument number for function with "
<< inputs.size() << " inputs.";
@@ -668,7 +668,7 @@ ValueAndGradFn value_and_grad(
auto gfun = [&fun](const std::vector<array>& inputs) {
auto outputs = fun(inputs);
for (int i = 1; i < outputs.size(); i++) {
for (int i = 1; i < std::ssize(outputs); i++) {
auto& out = outputs[i];
auto s = out.has_primitive() ? out.primitive().stream()
: default_stream(default_device());
@@ -701,7 +701,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
// Some error checking and get the vmap axis size
size_t vmap_ax_size;
for (int i = 0; i < inputs.size(); ++i) {
for (int i = 0; i < std::ssize(inputs); ++i) {
if (in_axes[i] != -1) {
if (inputs[i].ndim() == 0) {
throw std::invalid_argument(
@@ -717,7 +717,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
}
}
// Check that all vmapped axes have the same size
for (int i = 0; i < inputs.size(); ++i) {
for (int i = 0; i < std::ssize(inputs); ++i) {
if (in_axes[i] != -1) {
if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
std::ostringstream msg;
@@ -731,7 +731,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
// Run the function on placeholder inputs
// to get the original graph
std::vector<array> s_inputs;
for (int i = 0; i < inputs.size(); ++i) {
for (int i = 0; i < std::ssize(inputs); ++i) {
if (in_axes[i] != -1) {
auto shape = inputs[i].shape();
shape.erase(shape.begin() + in_axes[i]);
@@ -759,7 +759,7 @@ std::vector<array> vmap_replace(
}
int vmap_size = -1;
for (int i = 0; i < inputs.size(); ++i) {
for (int i = 0; i < std::ssize(inputs); ++i) {
if (in_axes[i] >= 0) {
vmap_size = inputs[i].shape(in_axes[i]);
break;
@@ -772,7 +772,7 @@ std::vector<array> vmap_replace(
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
std::unordered_set<std::uintptr_t> needs_vmap;
std::unordered_set<std::uintptr_t> cache;
for (int i = 0; i < s_inputs.size(); ++i) {
for (int i = 0; i < std::ssize(s_inputs); ++i) {
auto in = s_inputs[i];
if (in_axes[i] != -1) {
tmap.insert({in.id(), {inputs[i], in_axes[i]}});
@@ -843,7 +843,7 @@ std::vector<array> vmap_replace(
// For each primitive's outputs add its id, the vout id and the vax
auto outputs = a.outputs();
for (int i = 0; i < v_outputs.size(); ++i) {
for (int i = 0; i < std::ssize(v_outputs); ++i) {
tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});
}
}
@@ -851,7 +851,7 @@ std::vector<array> vmap_replace(
// Populate the outputs and make sure all the output axes are
// in the right place
std::vector<array> outputs;
for (int i = 0; i < s_outputs.size(); ++i) {
for (int i = 0; i < std::ssize(s_outputs); ++i) {
if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {
auto& [out, vdim] = map_it->second;
if (vdim != out_axes[i]) {
@@ -995,7 +995,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
// using `fun` directly because we may not be able to fully reuse
// the outputs of the forward pass.
fun_vjp.value_or(
[fun](auto primals, auto cotangents, auto outputs) {
[fun](auto primals, auto cotangents, auto /* outputs */) {
auto [__, vjps] = vjp(fun, primals, cotangents);
return vjps;
}),
@@ -1009,8 +1009,8 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
// waste computation.
fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
std::vector<array> all_tangents;
for (int i = 0, j = 0; i < primals.size(); i++) {
if (j < argnums.size() && i == argnums[j]) {
for (int i = 0, j = 0; i < std::ssize(primals); i++) {
if (j < std::ssize(argnums) && i == argnums[j]) {
all_tangents.emplace_back(tangents[j++]);
} else {
all_tangents.emplace_back(zeros_like(primals[i]));