diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp
index f7b5598ab..98c48cca9 100644
--- a/mlx/backend/common/compiled.cpp
+++ b/mlx/backend/common/compiled.cpp
@@ -1,8 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/common/compiled.h"
-#include "mlx/graph_utils.h"
-#include "mlx/primitives.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
   }
 }
 
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids) {
-  NodeNamer namer;
-  std::ostringstream os;
-  std::ostringstream constant_hasher;
-
-  // Fill the input names. This is not really necessary, I just like having A,
-  // B, C, ... as the inputs.
-  for (auto& x : inputs) {
-    namer.get_name(x);
-  }
-
-  // The primitives describing the tape. For unary and binary primitives this
-  // must be enough to describe the full computation.
-  for (auto& a : tape) {
-    // name and type of output
-    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
-    // computation performed
-    a.primitive().print(os);
-    // name of inputs to the function
-    for (auto& inp : a.inputs()) {
-      os << namer.get_name(inp);
-    }
-  }
-  os << "_";
-
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      os << "C";
-      print_constant(constant_hasher, x);
-    } else {
-      os << (is_scalar(x) ? "S" : "V");
-    }
-  }
-  os << "_";
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      continue;
-    }
-    os << kindof(x.dtype()) << x.itemsize();
-  }
-  os << "_" << std::hash<std::string>{}(constant_hasher.str());
-
-  return os.str();
-}
-
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
     const Shape& shape) {
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous) {
   if (contiguous) {
     int o = 0;
@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
       // - Donatable
       // - Not a constant
       if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-          in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          in.is_donatable() && !is_constant(i)) {
         outputs[o++].copy_shared_buffer(in);
       }
       // Get representative input flags to properly set non-donated outputs
@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          !is_constant(i)) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
   }
 }
 
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant) {
+  const Shape& shape = out.shape();
+  bool contiguous = compiled_check_contiguity(inputs, shape);
+  if (contiguous) {
+    return {true, shape, {}};
+  }
+
+  std::vector<Strides> strides_vec{out.strides()};
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    // Skip constants.
+    if (is_constant(i)) {
+      continue;
+    }
+
+    // Skip scalar inputs.
+    const auto& x = inputs[i];
+    if (is_scalar(x)) {
+      continue;
+    }
+
+    // Broadcast the inputs to the output shape.
+    Strides xstrides;
+    size_t j = 0;
+    for (; j < shape.size() - x.ndim(); ++j) {
+      if (shape[j] == 1) {
+        xstrides.push_back(out.strides()[j]);
+      } else {
+        xstrides.push_back(0);
+      }
+    }
+    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
+      if (x.shape(i) == 1) {
+        if (shape[j] == 1) {
+          xstrides.push_back(out.strides()[j]);
+        } else {
+          xstrides.push_back(0);
+        }
+      } else {
+        xstrides.push_back(x.strides()[i]);
+      }
+    }
+    strides_vec.push_back(std::move(xstrides));
+  }
+
+  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
+  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
+}
+
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous) {
+  if (contiguous) {
+    size_t max_size = 0;
+    for (const auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    return max_size > UINT32_MAX;
+  } else {
+    size_t max_size = 0;
+    for (const auto& o : outputs) {
+      max_size = std::max(max_size, o.size());
+    }
+    return max_size > UINT32_MAX;
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h
index f4d28d6ab..6fccaacd6 100644
--- a/mlx/backend/common/compiled.h
+++ b/mlx/backend/common/compiled.h
@@ -1,9 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #pragma once
+#include <functional>
 #include <string>
-#include <unordered_map>
-#include <unordered_set>
 
 #include "mlx/array.h"
 #include "mlx/primitives.h"
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
   return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
 }
 
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids);
-
 std::string get_type_string(Dtype d);
 
 template <typename T>
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
+    bool contiguous);
+
+// Collapse contiguous dims ignoring scalars and constants.
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant);
+
+// Return whether the kernel should use large index.
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
     bool contiguous);
 
 } // namespace mlx::core
diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp
index e389e0df5..d0bfb4f45 100644
--- a/mlx/backend/cpu/compiled.cpp
+++ b/mlx/backend/cpu/compiled.cpp
@@ -146,18 +146,9 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
 
 #ifdef _MSC_VER
@@ -170,14 +161,15 @@ inline void build_kernel(
 
   // Add the input arguments
   int cnt = 0;
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
-
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
 
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
+
     auto tstr = get_type_string(x.dtype());
     os << "  " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
        << "];" << std::endl;
@@ -211,10 +203,11 @@ inline void build_kernel(
   }
 
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    const auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
       print_constant(os, x);
       os << ";" << std::endl;
@@ -264,8 +257,9 @@ inline void build_kernel(
   } else {
     for (int d = ndim - 1; d >= 0; --d) {
       // Update pointers
-      for (auto& x : inputs) {
-        if (is_constant(x) || is_scalar(x)) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        if (is_constant(i) || is_scalar(x)) {
          continue;
         }
         auto& xname = namer.get_name(x);
@@ -287,65 +281,37 @@
 void Compiled::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
-  // Figure out which kernel we are using
-  auto& shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, shape);
   auto& encoder = cpu::get_command_encoder(stream());
 
-  // Handle all broadcasting and collect function input arguments
+  // Collapse contiguous dims to route to a faster kernel if possible. Also
+  // handle all broadcasting.
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Collect function input arguments.
   std::vector<void*> args;
-  std::vector<std::vector<size_t>> strides;
-  for (int i = 0; i < inputs.size(); i++) {
-    // Skip constants.
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+  int strides_index = 1;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (is_constant_(i)) {
       continue;
     }
-    auto& x = inputs[i];
+    const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data());
-
-    if (contiguous || is_scalar(x)) {
-      continue;
+    if (!contiguous && !is_scalar(x)) {
+      args.push_back(strides[strides_index++].data());
     }
-
-    // Broadcast the input to the output shape.
-    std::vector<size_t> xstrides;
-    int j = 0;
-    for (; j < shape.size() - x.ndim(); j++) {
-      if (shape[j] == 1) {
-        xstrides.push_back(outputs[0].strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (int i = 0; i < x.ndim(); i++, j++) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides.push_back(std::move(xstrides));
-    args.push_back(strides.back().data());
   }
 
   // Get the kernel name from the lib
   int ndim = shape.size();
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
   if (!contiguous) {
-    kernel_name += std::to_string(shape.size());
+    kernel_name += std::to_string(ndim);
   }
 
   // Get the function
-  auto fn_ptr = compile(kernel_name, [&]() {
+  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;
@@ -355,7 +321,7 @@ void Compiled::eval_cpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         contiguous,
         ndim);
     // Close extern "C"
@@ -363,26 +329,22 @@
     return kernel.str();
   });
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   for (auto& x : outputs) {
     args.push_back(x.data());
     encoder.set_output_array(x);
   }
-  Shape out_shape;
   if (!contiguous) {
-    out_shape = outputs[0].shape();
-    args.push_back((void*)out_shape.data());
+    args.push_back((void*)shape.data());
   } else {
     args.push_back((void*)outputs[0].data_size());
   }
   auto fun = (void (*)(void**))fn_ptr;
-  encoder.dispatch(
-      [fun,
-       args = std::move(args),
-       strides = std::move(strides),
-       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
+  encoder.dispatch([fun,
+                    args = std::move(args),
+                    strides = std::move(strides),
+                    shape = std::move(shape)]() mutable { fun(args.data()); });
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index db20f938c..6a67b4f57 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -11,8 +11,6 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
-using namespace fmt::literals;
-
 namespace mlx::core {
 
 inline void build_kernel(
@@ -21,21 +19,12 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim,
     bool dynamic_dims,
     bool use_big_index = false,
     int work_per_thread = 1) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
   bool add_indices = false;
   int cnt = 0;
@@ -45,14 +34,15 @@
       "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
 
   // Add the input arguments
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
-
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
 
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
+
     // Scalars and contiguous need no strides
     if (!is_scalar(x) && !contiguous) {
       add_indices = true;
@@ -80,8 +70,6 @@
   }
   // Add output strides and shape to extract the indices.
   if (!contiguous) {
-    os += fmt::format(
-        "    constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
     os += fmt::format(
         "    constant const int* output_shape [[buffer({0})]],\n", cnt++);
   } else {
@@ -125,7 +113,7 @@
     auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       auto type_str = get_type_string(x.dtype());
       std::ostringstream ss;
       print_constant(ss, x);
@@ -271,11 +259,6 @@
 void Compiled::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  // Make the name for the kernel library
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
   // Get the kernel if someone else built it already
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -290,7 +273,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
@@ -302,7 +285,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
@@ -315,7 +298,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -328,7 +311,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -342,7 +325,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -354,7 +337,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -363,70 +346,13 @@
     return kernel;
   });
 
-  // Figure out which kernel we are using
-  auto& output_shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, output_shape);
-
   // Collapse contiguous dims to route to a faster kernel if possible. Also
   // handle all broadcasting.
-  std::vector<Strides> initial_strides;
-  initial_strides.push_back(outputs[0].strides());
-  Shape shape;
-  std::vector<Strides> strides;
-  if (!contiguous) {
-    for (int i = 0; i < inputs.size(); i++) {
-      // Skip constants.
-      if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
-        continue;
-      }
-      auto& x = inputs[i];
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
 
-      // Skip scalar inputs.
-      if (is_scalar(x)) {
-        continue;
-      }
-
-      // Broadcast the inputs to the output shape.
-      Strides xstrides;
-      int j = 0;
-      for (; j < output_shape.size() - x.ndim(); j++) {
-        if (output_shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      }
-      for (int i = 0; i < x.ndim(); i++, j++) {
-        if (x.shape(i) == 1) {
-          if (output_shape[j] == 1) {
-            xstrides.push_back(outputs[0].strides()[j]);
-          } else {
-            xstrides.push_back(0);
-          }
-        } else {
-          xstrides.push_back(x.strides()[i]);
-        }
-      }
-      initial_strides.push_back(std::move(xstrides));
-    }
-    std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
-  }
-
-  bool large;
-  if (contiguous) {
-    size_t max_size = 0;
-    for (auto& in : inputs) {
-      max_size = std::max(max_size, in.data_size());
-    }
-    large = (max_size > UINT32_MAX);
-  } else {
-    size_t max_size = 0;
-    for (auto& o : outputs) {
-      max_size = std::max(max_size, o.size());
-    }
-    large = (max_size > UINT32_MAX);
-  }
+  // Whether to use large index.
+  bool large = compiled_use_large_index(inputs, outputs, contiguous);
 
   // Get the kernel from the lib
   int ndim = shape.size();
@@ -451,7 +377,7 @@
   int stride_idx = 1; // idx 0 is the output strides
   Strides in_strides;
   for (int i = 0; i < inputs.size(); i++) {
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+    if (is_constant_(i)) {
       continue;
     }
     auto& x = inputs[i];
@@ -468,8 +394,7 @@
     compute_encoder.set_vector_bytes(in_strides, cnt++);
   }
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   // Put the outputs in
   for (auto& x : outputs) {
@@ -478,7 +403,6 @@
 
   // Put the output shape and strides in
   if (!contiguous) {
-    compute_encoder.set_vector_bytes(strides[0], cnt++);
     compute_encoder.set_vector_bytes(shape, cnt++);
   } else {
     auto size = outputs[0].data_size();
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 2baeb6fcf..79a55ba8f 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -1,16 +1,20 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <cstdlib>
 #include <map>
+#include <sstream>
 #include <unordered_map>
 #include <unordered_set>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/compiled.h"
 #include "mlx/compile.h"
 #include "mlx/compile_impl.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/graph_utils.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"
+#include "mlx/utils.h"
 
 namespace mlx::core {
 
@@ -82,7 +86,54 @@ Compiled::Compiled(
       inputs_(std::move(inputs)),
       outputs_(std::move(outputs)),
       tape_(std::move(tape)),
-      constant_ids_(std::move(constant_ids)) {}
+      constant_ids_(std::move(constant_ids)),
+      is_constant_([this](size_t i) {
+        return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
+      }) {
+  // Build the kernel name.
+  NodeNamer namer;
+  std::ostringstream os;
+  std::ostringstream constant_hasher;
+
+  // Fill the input names. This is not really necessary, I just like having A,
+  // B, C, ... as the inputs.
+  for (const auto& x : inputs_) {
+    namer.get_name(x);
+  }
+
+  // The primitives describing the tape. For unary and binary primitives this
+  // must be enough to describe the full computation.
+  for (const auto& a : tape_) {
+    // name and type of output
+    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
+    // computation performed
+    a.primitive().print(os);
+    // name of inputs to the function
+    for (auto& inp : a.inputs()) {
+      os << namer.get_name(inp);
+    }
+  }
+  os << "_";
+
+  for (const auto& x : inputs_) {
+    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
+      os << "C";
+      print_constant(constant_hasher, x);
+    } else {
+      os << (is_scalar(x) ? "S" : "V");
+    }
+  }
+  os << "_";
+  for (const auto& x : inputs_) {
+    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
+      continue;
+    }
+    os << kindof(x.dtype()) << x.itemsize();
+  }
+  os << "_" << std::hash<std::string>{}(constant_hasher.str());
+
+  kernel_lib_ = os.str();
+}
 
 std::vector<array> Compiled::vjp(
     const std::vector<array>&,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index c0fbfc84d..cc60bcfb9 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -627,6 +627,7 @@ class Compiled : public Primitive {
   const std::vector<array> outputs_;
   const std::vector<array> tape_;
   const std::unordered_set<uintptr_t> constant_ids_;
+  const std::function<bool(size_t)> is_constant_;
 
   std::string kernel_lib_;
 };
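Reviewer note (not part of the patch): the heart of this refactor is replacing repeated unordered_set<uintptr_t> membership tests against a parallel inputs_ vector with a single index-based is_constant_ predicate that the Compiled constructor builds once and every backend reuses. The standalone C++ sketch below illustrates that pattern under stated assumptions; FakeArray and the literal ids are hypothetical stand-ins, not the mlx::core::array API.

    // Minimal sketch of the index-based is_constant predicate (hypothetical types).
    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <unordered_set>
    #include <vector>

    struct FakeArray {
      std::uintptr_t id; // stand-in for mlx::core::array::id()
    };

    int main() {
      std::vector<FakeArray> inputs{{1}, {2}, {3}};
      std::unordered_set<std::uintptr_t> constant_ids{2};

      // Equivalent of the is_constant_ member initialized in Compiled::Compiled.
      std::function<bool(std::size_t)> is_constant = [&](std::size_t i) {
        return constant_ids.find(inputs[i].id) != constant_ids.end();
      };

      // Call sites (build_kernel, eval_cpu, eval_gpu) now test an index instead
      // of repeating the set lookup.
      for (std::size_t i = 0; i < inputs.size(); ++i) {
        if (is_constant(i)) {
          continue; // constants are baked into the kernel, not bound as arguments
        }
        // ... bind inputs[i] as a kernel argument ...
      }
      return 0;
    }

One payoff of the callback form is that helpers such as compiled_allocate_outputs and compiled_collapse_contiguous_dims no longer need the inputs_/constant_ids_ pair threaded through their signatures, as the header changes above show.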