diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index a8000485a..52662158b 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -46,8 +46,9 @@ class RMSNorm : public Custom {
 
   static bool use_fallback(Stream stream);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -79,8 +80,9 @@ class RMSNormVJP : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -106,8 +108,9 @@ class LayerNorm : public Custom {
 
   static bool use_fallback(Stream s);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -138,8 +141,9 @@ class LayerNormVJP : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -174,8 +178,9 @@ class RoPE : public Custom {
 
   static bool use_fallback(Stream s);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -225,8 +230,9 @@ class ScaledDotProductAttention : public Custom {
       bool do_causal,
       Stream s);
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("NYI");
   }
 
@@ -320,8 +326,9 @@ class CustomKernel : public Primitive {
         is_precompiled_(is_precompiled),
         shared_memory_(shared_memory) {}
 
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
+  void eval_cpu(
+      const std::vector<array>& /* inputs */,
+      std::vector<array>& /* outputs */) override {
     throw std::runtime_error("Custom kernels only run on GPU.");
   }
 
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 3ec64feea..ec58fd2bd 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -89,7 +89,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
 
     auto& [a_ref, idx] = dfs.top();
     auto& a = a_ref.get();
-    if (idx < a.inputs().size()) {
+    if (idx < std::ssize(a.inputs())) {
       // Add an input, and continue
       auto& in = a.inputs()[idx++];
 
@@ -146,16 +146,16 @@ array eval_impl(std::vector<array> outputs, bool async) {
   int max_width = env::bfs_max_width();
   dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
   tape.push_back(synchronizer);
-  for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
-    auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
+  for (int i = 0; !cache.empty() && (i < std::ssize(tape) || !dfs.empty());) {
+    auto& a = (i >= std::ssize(tape)) ? dfs.top().first.get() : tape[i];
     int j = 0;
-    if (i >= tape.size()) {
+    if (i >= std::ssize(tape)) {
       j = dfs.top().second;
       dfs.pop();
     } else {
       i++;
     }
-    for (; j < a.inputs().size(); ++j) {
+    for (; j < std::ssize(a.inputs()); ++j) {
       auto& in = a.inputs()[j];
       if (in.status() != array::Status::unscheduled) {
         continue;
@@ -163,7 +163,7 @@
 
       // If the width limit is exceeded, push the array on the stack
       // and go down a level
-      if ((tape.size() - i) >= max_width) {
+      if ((std::ssize(tape) - i) >= max_width) {
         dfs.emplace(a, j);
         break;
       }
@@ -343,14 +343,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // that have stop_gradient called on them
   int cotan_index = 0;
   std::vector<std::pair<int, int>> output_cotan_pairs;
-  for (int i = 0; i < outputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(outputs); ++i) {
     auto& out = outputs[i];
     if (out.has_primitive()) {
       if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
         continue;
       }
     }
-    if (cotan_index >= cotans.size()) {
+    if (cotan_index >= std::ssize(cotans)) {
       std::ostringstream msg;
       msg << "[vjp] Number of outputs to compute gradients for ("
           << outputs.size() << ") does not match number of cotangents ("
@@ -374,11 +374,11 @@
   // to the tape which need a gradient.
   std::unordered_set<std::uintptr_t> cache;
   std::unordered_set<std::uintptr_t> calc_grad;
-  for (int i = 0, j = 0; i < primals_.size(); ++i) {
+  for (int i = 0, j = 0; i < std::ssize(primals_); ++i) {
     auto& primal = primals_[i];
     primal.set_tracer(false);
     cache.insert(primal.id());
-    if (j < argnums.size() && argnums[j] == i) {
+    if (j < std::ssize(argnums) && argnums[j] == i) {
       j++;
       calc_grad.insert(primal.id());
     }
@@ -440,7 +440,7 @@
 
     // Get the arguments whose gradients are needed
     std::vector<int> argnums;
-    for (int i = 0; i < a.inputs().size(); ++i) {
+    for (int i = 0; i < std::ssize(a.inputs()); ++i) {
       if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
         argnums.push_back(i);
       }
@@ -473,7 +473,7 @@
       vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
     }
     // Accumulate the vector-jacobian products for each input
-    for (int i = 0; i < argnums.size(); ++i) {
+    for (int i = 0; i < std::ssize(argnums); ++i) {
       auto in_id = a.inputs()[argnums[i]].id();
       if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
         cotan_it->second = add(cotan_it->second, vjps[i], s);
@@ -528,7 +528,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
     throw std::invalid_argument(
         "[jvp] Number of inputs does not match number of tangents.");
   }
-  for (int i = 0; i < primals.size(); ++i) {
+  for (int i = 0; i < std::ssize(primals); ++i) {
     if (primals[i].shape() != tangents[i].shape()) {
       throw std::invalid_argument(
           "[jvp] Input shape does not match shape of tangent.");
@@ -597,7 +597,7 @@
   }
 
   std::unordered_map<std::uintptr_t, array> tan_map;
-  for (int i = 0; i < primals_.size(); ++i) {
+  for (int i = 0; i < std::ssize(primals_); ++i) {
     tan_map.insert({primals_[i].id(), tangents[i]});
   }
 
@@ -605,7 +605,7 @@
     // Get the arguments used in the jvp
     std::vector<int> argnums;
     std::vector<array> tangents;
-    for (int i = 0; i < a.inputs().size(); ++i) {
+    for (int i = 0; i < std::ssize(a.inputs()); ++i) {
      if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
        argnums.push_back(i);
        tangents.push_back(it->second);
@@ -614,7 +614,7 @@
 
     auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);
     auto outputs = a.outputs();
-    for (int i = 0; i < jvps.size(); ++i) {
+    for (int i = 0; i < std::ssize(jvps); ++i) {
       tan_map.insert({outputs[i].id(), jvps[i]});
     }
   }
@@ -658,7 +658,7 @@ ValueAndGradFn value_and_grad(
       throw std::invalid_argument(
           "[grad] Repeat argument number not allowed in grad.");
     }
-    if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {
+    if (*args.begin() < 0 || *args.rbegin() >= std::ssize(inputs)) {
       std::ostringstream msg;
       msg << "[grad] Invalid argument number for function with "
          << inputs.size() << " inputs.";
@@ -668,7 +668,7 @@
 
     auto gfun = [&fun](const std::vector<array>& inputs) {
       auto outputs = fun(inputs);
-      for (int i = 1; i < outputs.size(); i++) {
+      for (int i = 1; i < std::ssize(outputs); i++) {
         auto& out = outputs[i];
         auto s = out.has_primitive() ? out.primitive().stream()
                                      : default_stream(default_device());
@@ -701,7 +701,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
 
   // Some error checking and get the vmap axis size
   size_t vmap_ax_size;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] != -1) {
       if (inputs[i].ndim() == 0) {
         throw std::invalid_argument(
@@ -717,7 +717,7 @@
     }
   }
   // Check that all vmapped axes have the same size
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] != -1) {
       if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
         std::ostringstream msg;
@@ -731,7 +731,7 @@
   // Run the function on placeholder inputs
   // to get the original graph
   std::vector<array> s_inputs;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] != -1) {
       auto shape = inputs[i].shape();
       shape.erase(shape.begin() + in_axes[i]);
@@ -759,7 +759,7 @@ std::vector<array> vmap_replace(
   }
 
   int vmap_size = -1;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     if (in_axes[i] >= 0) {
       vmap_size = inputs[i].shape(in_axes[i]);
       break;
@@ -772,7 +772,7 @@
   std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
   std::unordered_set<std::uintptr_t> needs_vmap;
   std::unordered_set<std::uintptr_t> cache;
-  for (int i = 0; i < s_inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(s_inputs); ++i) {
     auto in = s_inputs[i];
     if (in_axes[i] != -1) {
       tmap.insert({in.id(), {inputs[i], in_axes[i]}});
@@ -843,7 +843,7 @@
 
     // For each primitive's outputs add its id, the vout id and the vax
     auto outputs = a.outputs();
-    for (int i = 0; i < v_outputs.size(); ++i) {
+    for (int i = 0; i < std::ssize(v_outputs); ++i) {
       tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});
     }
   }
@@ -851,7 +851,7 @@
   // Populate the outputs and make sure all the output axes are
   // in the right place
   std::vector<array> outputs;
-  for (int i = 0; i < s_outputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(s_outputs); ++i) {
     if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {
       auto& [out, vdim] = map_it->second;
       if (vdim != out_axes[i]) {
@@ -995,7 +995,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
             // using `fun` directly because we may not be able to fully reuse
             // the outputs of the forward pass.
             fun_vjp.value_or(
-                [fun](auto primals, auto cotangents, auto outputs) {
+                [fun](auto primals, auto cotangents, auto /* outputs */) {
                   auto [__, vjps] = vjp(fun, primals, cotangents);
                   return vjps;
                 }),
@@ -1009,8 +1009,8 @@
             // waste computation.
             fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
               std::vector<array> all_tangents;
-              for (int i = 0, j = 0; i < primals.size(); i++) {
-                if (j < argnums.size() && i == argnums[j]) {
+              for (int i = 0, j = 0; i < std::ssize(primals); i++) {
+                if (j < std::ssize(argnums) && i == argnums[j]) {
                   all_tangents.emplace_back(tangents[j++]);
                 } else {
                   all_tangents.emplace_back(zeros_like(primals[i]));
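
Note (illustrative, not part of the patch): the changes above swap unsigned `size()` results for `std::ssize` (C++20, declared in <iterator>), which returns a signed value and keeps comparisons against an `int` loop index free of sign-compare warnings; the header changes comment out unused parameter names to silence -Wunused-parameter while preserving self-documenting signatures. A minimal standalone sketch of both idioms, using a hypothetical `eval_cpu_stub` function rather than the MLX types:

#include <cstdio>
#include <iterator>
#include <vector>

// Commenting out the unused parameter names keeps the signature readable
// without triggering -Wunused-parameter.
void eval_cpu_stub(
    const std::vector<int>& /* inputs */,
    std::vector<int>& /* outputs */) {
  // Intentionally empty in this sketch.
}

int main() {
  std::vector<int> v{1, 2, 3};
  // std::ssize(v) yields a signed ptrdiff_t, so comparing it with an
  // int index does not produce a -Wsign-compare warning.
  for (int i = 0; i < std::ssize(v); ++i) {
    std::printf("%d\n", v[i]);
  }
  std::vector<int> out;
  eval_cpu_stub(v, out);
  return 0;
}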