diff --git a/benchmarks/python/compile_bench.py b/benchmarks/python/compile_bench.py new file mode 100644 index 000000000..0d5d9f61d --- /dev/null +++ b/benchmarks/python/compile_bench.py @@ -0,0 +1,109 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse +import math +import random + +import mlx.core as mx +from time_utils import time_fn + + +def bench_gelu(): + + def gelu(x): + return x * (1 + mx.erf(x / math.sqrt(2))) / 2 + + x = mx.random.uniform(shape=(1000, 1024)) + + def gen_fun(fun): + def bench_fun(x): + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + time_fn(gen_fun(gelu), x, msg="fixed gelu") + time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu") + + def randint(): + return random.randint(1, x.shape[0]) + + def gen_fun(fun): + def bench_fun(x, y): + x = x[: randint()] + for _ in range(10): + x = fun(x) + y = fun(y) + return x, y + + return bench_fun + + y = mx.random.uniform(shape=(1000, 1024)) + time_fn(gen_fun(gelu), x, y, msg="variable gelu") + time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu") + time_fn( + gen_fun(mx.compile(gelu, shapeless=True)), + x, + y, + msg="shapeless variable gelu", + ) + + +def bench_layernorm(): + + weight = mx.random.uniform(shape=(4096,)).astype(mx.float16) + bias = mx.random.uniform(shape=(4096,)).astype(mx.float16) + mx.eval(weight, bias) + + def layernorm(x): + x = x.astype(mx.float32) + means = mx.mean(x, axis=-1, keepdims=True) + var = mx.var(x, axis=-1, keepdims=True) + x = (x - means) * mx.rsqrt(var + 1e-4) + x = x.astype(mx.float16) + return weight * x + bias + + x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16) + + def gen_fun(fun): + def bench_fun(x): + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + time_fn(gen_fun(layernorm), x, msg="fixed layernorm") + time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm") + + def randint(): + return random.randint(1, x.shape[0]) + + def gen_fun(fun): + def bench_fun(x): + x = x[: randint()] + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + random.seed(0) + time_fn(gen_fun(layernorm), x, msg="variable layernorm") + random.seed(0) + time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm") + random.seed(0) + time_fn( + gen_fun(mx.compile(layernorm, shapeless=True)), + x, + msg="shapeless variable layernorm", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Compile benchmarks.") + args = parser.parse_args() + + bench_gelu() + bench_layernorm() diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py index f10635ec9..2903c3293 100644 --- a/benchmarks/python/time_utils.py +++ b/benchmarks/python/time_utils.py @@ -6,7 +6,11 @@ import mlx.core as mx def time_fn(fn, *args, **kwargs): - print(f"Timing {fn.__name__} ...", end=" ") + msg = kwargs.pop("msg", None) + if msg: + print(f"Timing {msg} ...", end=" ") + else: + print(f"Timing {fn.__name__} ...", end=" ") # warmup for _ in range(5): diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 52bcac4fa..529ad2fa5 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -37,7 +37,7 @@ std::string build_lib_name( os << "C"; print_constant(constant_hasher, x); } else { - os << ((x.size() == 1) ? "S" : "V"); + os << (is_scalar(x) ? 
"S" : "V"); } } os << "_"; @@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) { } } -inline bool is_scalar(const array& x) { - return x.size() == 1; -}; - // Return a pointer to a compiled function void* compile( const std::string& kernel_name, @@ -358,7 +354,7 @@ void Compiled::eval_cpu( bool all_col_contig = true; int non_scalar_inputs = 0; for (auto& x : inputs) { - if (x.size() == 1) { + if (is_scalar(x)) { continue; } non_scalar_inputs++; @@ -385,7 +381,7 @@ void Compiled::eval_cpu( auto& x = inputs[i]; args.push_back((void*)x.data()); - if (contiguous || x.size() <= 1) { + if (contiguous || is_scalar(x)) { continue; } @@ -458,7 +454,7 @@ void Compiled::eval_cpu( // - Donatable // - Correct size // - Not a constant - if (in.flags().contiguous && in.size() > 1 && in.is_donatable() && + if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() && constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { outputs[o++].copy_shared_buffer(in); } diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index adbd5399c..d01fe4fdc 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) { void print_constant(std::ostream& os, const array& x); +inline bool is_scalar(const array& x) { + return x.ndim() == 0; +} + } // namespace mlx::core diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 681d635ba..3b1ee116a 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -31,9 +31,6 @@ inline void build_kernel( return constant_ids.find(x.id()) != constant_ids.end(); }; - // For scalar we shouldn't do the indexing things, just read at 0 - auto is_scalar = [](const array& x) { return x.size() == 1; }; - NodeNamer namer; bool add_indices = false; int cnt = 0; @@ -226,8 +223,7 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ true); - kernel_source_ = kernel.str(); - lib = d.get_library(kernel_lib_, kernel_source_); + lib = d.get_library(kernel_lib_, kernel.str()); } // Figure out which kernel we are using @@ -235,7 +231,7 @@ void Compiled::eval_gpu( bool contiguous = true; for (auto& x : inputs) { if ((!x.flags().row_contiguous || x.shape() != output_shape) && - x.size() > 1) { + !is_scalar(x)) { contiguous = false; break; } @@ -256,7 +252,7 @@ void Compiled::eval_gpu( auto& x = inputs[i]; // Skip scalar inputs. 
- if (x.size() <= 1) { + if (is_scalar(x)) { continue; } @@ -311,7 +307,7 @@ void Compiled::eval_gpu( } auto& x = inputs[i]; set_array_buffer(compute_encoder, x, cnt++); - if (!contiguous && x.size() > 1) { + if (!contiguous && !is_scalar(x)) { compute_encoder->setBytes( strides[stride_idx].data(), strides[stride_idx].size() * sizeof(size_t), diff --git a/mlx/compile.cpp b/mlx/compile.cpp index a648d191f..700c07ced 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -13,7 +13,7 @@ namespace mlx::core { -constexpr int max_compile_depth = 10; +constexpr int max_compile_depth = 11; bool is_unary(const Primitive& p) { return ( @@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) { return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient); } +bool is_reduction(const Primitive& p) { + return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce); +} + bool is_fusable(const Primitive& p) { return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p); } -namespace detail { - -std::vector compile_replace( - const std::vector& tape, - const std::vector& trace_inputs, - const std::vector& trace_outputs, - const std::vector& inputs); - -} // namespace detail +bool allows_shapeless(const Primitive& p) { + return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) || + is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) || + typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) || + typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition); +} Compiled::Compiled( Stream stream, @@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) { } } +std::vector> Compiled::output_shapes( + const std::vector& inputs) { + size_t nd = 0; + for (auto& in : inputs) { + nd = std::max(nd, in.ndim()); + } + std::vector out_shape(nd, 0); + for (auto& in : inputs) { + auto dd = nd - in.ndim(); + for (auto i = dd; i < nd; ++i) { + out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]); + } + } + // All outputs have the same shape + return std::vector>(outputs_.size(), out_shape); +} + namespace detail { CompileMode& compile_mode() { @@ -180,21 +198,30 @@ struct CompilerCache { std::vector outputs; std::vector tape; bool empty{true}; + std::vector constants; }; // Returns a reference to a CacheEntry which can be updated // by the caller to avoid copying large tapes / inputs / outputs - CacheEntry& find(size_t fun_id, const std::vector& inputs) { + CacheEntry& find( + size_t fun_id, + const std::vector& inputs, + bool shapeless, + const std::vector& constants) { // Try to find the entry auto [entry_it, inserted] = cache_.insert({fun_id, {}}); auto& entries = entry_it->second; - auto is_match = [](const std::vector& in1, - const std::vector& in2) { + auto is_match = [shapeless]( + const std::vector& in1, + const std::vector& in2) { if (in1.size() != in2.size()) { return false; } for (int i = 0; i < in1.size(); ++i) { - if (in1[i].shape() != in2[i].shape()) { + if (in1[i].ndim() != in2[i].ndim()) { + return false; + } + if (!shapeless && in1[i].shape() != in2[i].shape()) { return false; } if (in1[i].dtype() != in2[i].dtype()) { @@ -210,7 +237,7 @@ struct CompilerCache { // more easily searchable structure. 
for (auto& entry : entries) { // Check the inputs match and return if so - if (is_match(inputs, entry.inputs)) { + if (is_match(inputs, entry.inputs) && constants == entry.constants) { return entry; } } @@ -651,7 +678,8 @@ std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, const std::vector& trace_outputs, - const std::vector& inputs) { + const std::vector& inputs, + bool shapeless) { std::unordered_map trace_to_real; for (int i = 0; i < inputs.size(); ++i) { trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); @@ -669,18 +697,29 @@ std::vector compile_replace( real_inputs.push_back(trace_to_real.at(in.id())); } if (a.siblings().empty()) { + auto shape = + shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape(); auto real_a = array( - a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); + std::move(shape), + a.dtype(), + a.primitive_ptr(), + std::move(real_inputs)); trace_to_real.insert({a.id(), std::move(real_a)}); } else { // Ensure the order is correct for multi-output primitives - std::vector> shapes; std::vector types; auto trace_out = a.outputs(); for (auto& o : trace_out) { - shapes.push_back(o.shape()); types.push_back(o.dtype()); } + std::vector> shapes; + if (shapeless) { + shapes = a.primitive().output_shapes(real_inputs); + } else { + for (auto& o : trace_out) { + shapes.push_back(o.shape()); + } + } auto real_out = array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs); for (int i = 0; i < trace_out.size(); ++i) { @@ -697,13 +736,34 @@ std::vector compile_replace( return outputs; } +void compile_validate_shapeless(const std::vector& tape) { + for (auto& t : tape) { + if (!t.has_primitive()) { + continue; + } + auto& p = t.primitive(); + if (allows_shapeless(p)) { + continue; + } + + std::ostringstream msg; + msg << "[compile] Cannot compile primitive "; + p.print(msg); + msg << " with shapeless enabled."; + throw std::invalid_argument(msg.str()); + } +} + std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id) { + size_t fun_id, + bool shapeless /* = false */, + std::vector constants /* = {} */) { if (compile_mode() == CompileMode::disabled) { return fun; } - return [fun, fun_id](const std::vector& inputs) { + return [fun, fun_id, shapeless, constants = std::move(constants)]( + const std::vector& inputs) { // If the inputs are tracers, trace the original graph if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { return in.is_tracer(); @@ -712,12 +772,14 @@ std::function(const std::vector&)> compile( } // Find a cache entry with the correct inputs - auto& entry = compiler_cache().find(fun_id, inputs); + auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants); // No matching cache entry existed, so compile if (entry.empty) { // Mark the entry as not empty since we are about to fill it entry.empty = false; + // Set the constants + entry.constants = std::move(constants); // Trace to build the graph std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); @@ -739,11 +801,16 @@ std::function(const std::vector&)> compile( if (compile_mode() != CompileMode::no_fuse) { compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs); } + + if (shapeless) { + compile_validate_shapeless(entry.tape); + } } // At this point we must have a tape, now replace the placeholders // with real arrays that can be evaluated - return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs); + return compile_replace( + 
entry.tape, entry.inputs, entry.outputs, inputs, shapeless); }; } @@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) { } // namespace detail std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun) { + const std::function(const std::vector&)>& fun, + bool shapeless /* false */) { if (detail::compile_mode() == CompileMode::disabled) { return fun; } auto fun_id = detail::getAddress(fun); - return detail::compile(fun, fun_id); + return detail::compile(fun, fun_id, shapeless); } void disable_compile() { diff --git a/mlx/compile.h b/mlx/compile.h index fb3115d61..1134c20dc 100644 --- a/mlx/compile.h +++ b/mlx/compile.h @@ -8,9 +8,10 @@ namespace mlx::core { enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; -// Compile takes a function and returns a new function +/** Compile takes a function and returns a compiled function. */ std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun); + const std::function(const std::vector&)>& fun, + bool shapeless = false); /** Globally disable compilation. * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b78f8d405..b2daaa2f8 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -71,6 +71,15 @@ std::pair, std::vector> Primitive::vmap( throw std::invalid_argument("Primitive's vmap not implemented."); }; +std::vector> Primitive::output_shapes( + const std::vector&) { + std::ostringstream msg; + msg << "[Primitive::output_shapes] "; + this->print(msg); + msg << " cannot infer output shapes."; + throw std::invalid_argument(msg.str()); +}; + std::vector Abs::vjp( const std::vector& primals, const std::vector& cotangents, @@ -383,6 +392,13 @@ std::pair, std::vector> ArgSort::vmap( return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; } +std::vector> ArgReduce::output_shapes( + const std::vector& inputs) { + auto out_shape = inputs[0].shape(); + out_shape[axis_] = 1; + return {out_shape}; +} + bool ArgSort::is_equivalent(const Primitive& other) const { const ArgSort& r_other = static_cast(other); return axis_ == r_other.axis_; @@ -2202,6 +2218,15 @@ bool Reduce::is_equivalent(const Primitive& other) const { return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_; } +std::vector> Reduce::output_shapes( + const std::vector& inputs) { + std::vector out_shape = inputs[0].shape(); + for (auto i : axes_) { + out_shape[i] = 1; + } + return {out_shape}; +} + std::vector Round::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 9d0a9181c..73e4394a5 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -36,6 +36,12 @@ return true; \ } +#define DEFINE_INPUT_OUTPUT_SHAPE() \ + std::vector> output_shapes( \ + const std::vector& inputs) override { \ + return {inputs[0].shape()}; \ + }; + namespace mlx::core { // Abstract base class @@ -102,6 +108,11 @@ class Primitive { return false; } + /** Get the output shapes of the primitive. This is not required to be + * implemented by derived classes, in which case it will throw. 
*/ + virtual std::vector> output_shapes( + const std::vector& inputs); + virtual ~Primitive() = default; Primitive(const Primitive& other) = delete; Primitive(Primitive&& other) = delete; @@ -152,6 +163,7 @@ class Abs : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -168,6 +180,7 @@ class Add : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Add) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -226,6 +239,7 @@ class ArcCos : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -242,6 +256,7 @@ class ArcCosh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -258,6 +273,7 @@ class ArcSin : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -274,6 +290,7 @@ class ArcSinh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -290,6 +307,7 @@ class ArcTan : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -306,6 +324,7 @@ class ArcTanh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -321,6 +340,7 @@ class ArgPartition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgPartition) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -346,6 +366,8 @@ class ArgReduce : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgReduce) bool is_equivalent(const Primitive& other) const override; + std::vector> output_shapes( + const std::vector& inputs) override; private: ReduceType reduce_type_; @@ -364,6 +386,7 @@ class ArgSort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgSort) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -383,6 +406,7 @@ class AsType : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(AsType) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -448,6 +472,7 @@ class Ceil : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Ceil) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -478,15 +503,14 @@ class Compiled : public Primitive { DEFINE_VMAP() DEFINE_GRADS() + std::vector> output_shapes( + const std::vector& inputs) override; void print(std::ostream& os) override; bool is_equivalent(const Primitive& other) const override; - std::string metal_lib_name() const { + std::string lib_name() const { return kernel_lib_; } - std::string metal_lib_source() const { - return kernel_source_; - } private: const std::vector inputs_; @@ -495,7 +519,6 @@ class Compiled : public Primitive { const std::unordered_set constant_ids_; std::string kernel_lib_; 
- std::string kernel_source_; }; class Concatenate : public UnaryPrimitive { @@ -563,6 +586,7 @@ class Copy : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -579,6 +603,7 @@ class Cos : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -595,6 +620,7 @@ class Cosh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -665,6 +691,7 @@ class Divide : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -683,6 +710,10 @@ class DivMod : public Primitive { DEFINE_GRADS() DEFINE_PRINT(DivMod) DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector> output_shapes( + const std::vector& inputs) override { + return std::vector{inputs[0].shape(), inputs[0].shape()}; + }; private: void eval(const std::vector& inputs, std::vector& outputs); @@ -699,6 +730,7 @@ class Remainder : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Remainder) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -715,6 +747,7 @@ class Equal : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() void print(std::ostream& os) override { if (equal_nan_) { @@ -740,6 +773,7 @@ class Erf : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -756,6 +790,7 @@ class ErfInv : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -772,6 +807,7 @@ class Exp : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -814,6 +850,7 @@ class Floor : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Floor) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -868,6 +905,7 @@ class Greater : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -884,6 +922,7 @@ class GreaterEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -900,6 +939,7 @@ class Less : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Less) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -916,6 +956,7 @@ class LessEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -958,6 +999,7 @@ class Log : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() void 
print(std::ostream& os) override { switch (base_) { @@ -988,6 +1030,7 @@ class Log1p : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Log1p) + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1004,6 +1047,7 @@ class LogicalNot : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1020,6 +1064,7 @@ class LogicalAnd : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalAnd) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1036,6 +1081,7 @@ class LogicalOr : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalOr) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1052,6 +1098,7 @@ class LogAddExp : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1085,6 +1132,7 @@ class Maximum : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1101,6 +1149,7 @@ class Minimum : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1117,6 +1166,7 @@ class Multiply : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1133,6 +1183,7 @@ class Negative : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1149,6 +1200,7 @@ class NotEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1193,6 +1245,7 @@ class Partition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Partition) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -1213,6 +1266,7 @@ class Power : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Power) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1305,6 +1359,9 @@ class Reduce : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; + std::vector> output_shapes( + const std::vector& inputs) override; + void print(std::ostream& os) override { switch (reduce_type_) { case And: @@ -1347,6 +1404,7 @@ class Round : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Round) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1455,6 +1513,7 @@ class Sigmoid : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1471,6 +1530,7 @@ class Sign : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void 
eval(const std::vector& inputs, array& out); @@ -1487,6 +1547,7 @@ class Sin : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1503,6 +1564,7 @@ class Sinh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1547,6 +1609,7 @@ class Softmax : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Softmax) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1563,6 +1626,7 @@ class Sort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sort) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -1604,6 +1668,7 @@ class Square : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Square) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1619,6 +1684,7 @@ class Sqrt : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; void print(std::ostream& os) override { @@ -1644,6 +1710,7 @@ class StopGradient : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1660,6 +1727,7 @@ class Subtract : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1676,6 +1744,7 @@ class Tan : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1692,6 +1761,7 @@ class Tanh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index a1b3461ab..6c7959426 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -18,7 +18,9 @@ std::vector vmap_replace( // idea. std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id); + size_t fun_id, + bool shapeless = false, + std::vector constants = {}); // Erase cached compile functions void compile_erase(size_t fun_id); diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index db07ce190..dfd435cfd 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import math +from functools import partial from typing import Any import mlx.core as mx @@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module def _make_activation_module(f): def decorator(klass): - klass.__doc__ = f.__doc__ - klass.__call__ = lambda self, x: f(x) + klass.__call__ = lambda _, x: f(x) return klass return decorator +@partial(mx.compile, shapeless=True) def sigmoid(x): r"""Applies the element-wise function: @@ -25,6 +26,7 @@ def sigmoid(x): return mx.sigmoid(x) +@partial(mx.compile, shapeless=True) def relu(x): r"""Applies the Rectified Linear Unit. 
@@ -33,6 +35,7 @@ def relu(x): return mx.maximum(x, 0) +@partial(mx.compile, shapeless=True) def leaky_relu(x, negative_slope=0.01): r"""Applies the Leaky Rectified Linear Unit. @@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01): return mx.maximum(negative_slope * x, x) +@partial(mx.compile, shapeless=True) def log_softmax(x, axis=-1): r"""Applies the Log Softmax function. @@ -49,6 +53,7 @@ def log_softmax(x, axis=-1): return x - mx.logsumexp(x, axis=axis, keepdims=True) +@partial(mx.compile, shapeless=True) def elu(x, alpha=1.0): r"""Applies the Exponential Linear Unit. @@ -57,6 +62,7 @@ def elu(x, alpha=1.0): return mx.where(x > 0, x, alpha * (mx.exp(x) - 1)) +@partial(mx.compile, shapeless=True) def relu6(x): r"""Applies the Rectified Linear Unit 6. @@ -65,6 +71,7 @@ def relu6(x): return mx.minimum(mx.maximum(x, 0), 6.0) +@partial(mx.compile, shapeless=True) def softmax(x, axis=-1): r"""Applies the Softmax function. @@ -73,6 +80,7 @@ def softmax(x, axis=-1): return mx.softmax(x, axis=axis) +@partial(mx.compile, shapeless=True) def softplus(x): r"""Applies the Softplus function. @@ -81,6 +89,7 @@ def softplus(x): return mx.logaddexp(x, 0) +@partial(mx.compile, shapeless=True) def softsign(x): r"""Applies the Softsign function. @@ -89,6 +98,7 @@ def softsign(x): return mx.divide(x, 1 + mx.abs(x)) +@partial(mx.compile, shapeless=True) def softshrink(x, lambd: float = 0.5): r"""Applies the Softshrink activation function. @@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5): return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0) +@partial(mx.compile, shapeless=True) def celu(x, alpha=1.0): r"""Applies the Continuously Differentiable Exponential Linear Unit. @@ -111,6 +122,7 @@ def celu(x, alpha=1.0): return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1) +@partial(mx.compile, shapeless=True) def silu(x): r"""Applies the Sigmoid Linear Unit. Also known as Swish. @@ -120,6 +132,7 @@ def silu(x): return x * mx.sigmoid(x) +@partial(mx.compile, shapeless=True) def log_sigmoid(x): r"""Applies the Log Sigmoid function. @@ -128,6 +141,7 @@ def log_sigmoid(x): return -softplus(-x) +@partial(mx.compile, shapeless=True) def gelu(x): r"""Applies the Gaussian Error Linear Units function. @@ -142,6 +156,7 @@ def gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 +@partial(mx.compile, shapeless=True) def gelu_approx(x): r"""An approximation to Gaussian Error Linear Unit. @@ -159,6 +174,7 @@ def gelu_approx(x): return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square())) +@partial(mx.compile, shapeless=True) def gelu_fast_approx(x): r"""A fast approximation to Gaussian Error Linear Unit. @@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array: return a * mx.sigmoid(b) -class GLU(Module): - r"""Applies the gated linear unit function. - - This function splits the ``axis`` dimension of the input into two halves - (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. - - .. math:: - textrm{GLU}(x) = a * \sigma(b) - - Args: - axis (int): The dimension to split along. Default: ``-1`` - """ - - def __init__(self, axis: int = -1): - super().__init__() - self.axis = axis - - def __call__(self, x) -> Any: - return glu(x=x, axis=self.axis) - - +@partial(mx.compile, shapeless=True) def step(x: mx.array, threshold: float = 0.0): r"""Applies the Step Activation Function. 
@@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0): return mx.where(x > threshold, 1, 0) +@partial(mx.compile, shapeless=True) def selu(x): r"""Applies the Scaled Exponential Linear Unit. @@ -248,6 +245,7 @@ def selu(x): return elu(x, 1.67326) * 1.0507 +@partial(mx.compile, shapeless=True) def prelu(x: mx.array, alpha: mx.array) -> mx.array: r"""Applies the element-wise parametric ReLU. @@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array: return mx.maximum(0, x) + alpha * mx.minimum(0, x) +@partial(mx.compile, shapeless=True) def mish(x: mx.array) -> mx.array: r"""Applies the Mish function, element-wise. Mish: A Self Regularized Non-Monotonic Neural Activation Function. @@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array: return x * mx.tanh(softplus(x)) +@partial(mx.compile, shapeless=True) def hardswish(x): r"""Applies the hardswish function, element-wise. @@ -282,6 +282,35 @@ def hardswish(x): return x * mx.minimum(max_x_3, 6) / 6 +def tanh(x): + """Applies the hyperbolic tangent function. + + Simply ``mx.tanh(x)``. + """ + return mx.tanh(x) + + +class GLU(Module): + r"""Applies the gated linear unit function. + + This function splits the ``axis`` dimension of the input into two halves + (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. + + .. math:: + textrm{GLU}(x) = a * \sigma(b) + + Args: + axis (int): The dimension to split along. Default: ``-1`` + """ + + def __init__(self, axis: int = -1): + super().__init__() + self.axis = axis + + def __call__(self, x) -> Any: + return glu(x=x, axis=self.axis) + + @_make_activation_module(mx.sigmoid) class Sigmoid(Module): r"""Applies the sigmoid function, element-wise. @@ -500,14 +529,6 @@ class GELU(Module): return self._act(x) -def tanh(x): - """Applies the hyperbolic tangent function. - - Simply ``mx.tanh(x)``. - """ - return mx.tanh(x) - - @_make_activation_module(tanh) class Tanh(Module): r"""Applies the hyperbolic tangent function. 
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index f081fdedd..cda1d6316 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -555,13 +555,19 @@ struct PyCompiledFun { size_t fun_id; py::object captured_inputs; py::object captured_outputs; + bool shapeless; size_t num_outputs{0}; - PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs) + PyCompiledFun( + const py::function& fun, + py::object inputs, + py::object outputs, + bool shapeless) : fun(fun), fun_id(reinterpret_cast(fun.ptr())), captured_inputs(inputs), - captured_outputs(outputs) {} + captured_outputs(outputs), + shapeless(shapeless) {} PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete; @@ -571,11 +577,15 @@ struct PyCompiledFun { other.fun_id = 0; captured_inputs = std::move(other.captured_inputs); captured_outputs = std::move(other.captured_outputs); + shapeless = other.shapeless; num_outputs = other.num_outputs; }; - py::object operator()(const py::args& args) { - auto compile_fun = [this, &args](const std::vector& a) { + py::object operator()(const py::args& args, const py::kwargs& kwargs) { + auto inputs = tree_flatten(args, false); + + auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()]( + const std::vector& a) { // Put tracers into captured inputs std::vector flat_in_captures; std::vector trace_captures; @@ -586,8 +596,10 @@ struct PyCompiledFun { tree_fill(captured_inputs, trace_captures); } - auto [outputs, py_outputs] = tree_flatten_with_structure( - std::move(fun(*tree_unflatten(args, a))), false); + auto tree_outputs = + fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args)); + auto [outputs, py_outputs] = + tree_flatten_with_structure(std::move(tree_outputs), false); tree_cache().insert({fun_id, py_outputs}); @@ -607,7 +619,14 @@ struct PyCompiledFun { return outputs; }; - auto inputs = tree_flatten(args, false); + { + auto flat_kwargs = tree_flatten(kwargs, false); + inputs.insert( + inputs.end(), + std::make_move_iterator(flat_kwargs.begin()), + std::make_move_iterator(flat_kwargs.end())); + } + if (!py::isinstance(captured_inputs)) { auto flat_in_captures = tree_flatten(captured_inputs, false); inputs.insert( @@ -616,8 +635,39 @@ struct PyCompiledFun { std::make_move_iterator(flat_in_captures.end())); } + // Collect the compilation constants + std::vector constants; + auto value_hash = [](py::handle o) -> std::optional { + // Consider expanding tuples to their contents including start and end + // ids + if (py::isinstance(o) || py::isinstance(o)) { + auto r = py::hash(o); + return *reinterpret_cast(&r); + } else if (py::isinstance(o)) { + auto r = o.cast(); + return *reinterpret_cast(&r); + } else if (py::isinstance(o)) { + auto r = o.cast(); + return *reinterpret_cast(&r); + } else { + return std::nullopt; + } + }; + for (int i = 0; i < args.size(); i++) { + if (auto h = value_hash(args[i]); h.has_value()) { + constants.push_back(*h); + } + } + for (auto& pair : kwargs) { + if (auto h = value_hash(pair.second); h.has_value()) { + constants.push_back(*value_hash(pair.first)); + constants.push_back(*h); + } + } + // Compile and call - auto outputs = detail::compile(compile_fun, fun_id)(inputs); + auto outputs = + detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); if (!py::isinstance(captured_outputs)) { std::vector captures( std::make_move_iterator(outputs.begin() + num_outputs), @@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) { "compile", 
[](const py::function& fun, const py::object& inputs, - const py::object& outputs) { - return py::cpp_function(PyCompiledFun{fun, inputs, outputs}); + const py::object& outputs, + bool shapeless) { + return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}); }, "fun"_a, "inputs"_a = std::nullopt, "outputs"_a = std::nullopt, + "shapeless"_a = false, R"pbdoc( compile(fun: function) -> function @@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) { :obj:`list` or a :obj:`dict` containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. Default: ``None`` + shapeless (bool, optional): A function compiled with the ``shapeless`` + option enabled will not be recompiled when the input shape changes. Not all + functions can be compiled with ``shapeless`` enabled. Attempting to compile + such functions with shapeless enabled will throw. Note, changing the number + of dimensions or type of any input will result in a recompilation even with + ``shapeless`` set to ``True``. Default: ``False`` Returns: function: A compiled function which has the same input arguments diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 2e0bb1d7f..e53134482 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -381,6 +381,164 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) + def test_compile_kwargs(self): + + @mx.compile + def fun(x, y, z): + return x + y + z + + x = mx.array(1) + y = mx.array(2) + z = mx.array(3) + out = fun(x, y=y, z=z) + self.assertEqual(out.item(), 6) + + def test_shapeless_compile(self): + y = 1 + + @partial(mx.compile, shapeless=True) + def fun(x): + return x + y + + x = mx.array([1, 2]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) + + # The function is not recompiled, so the change + # to y should not be reflected in the output + y = 2 + x = mx.array([1, 2, 3]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) + + # Type change recompiles + x = mx.array([1.0, 2.0, 3.0]) + self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) + fun(x, y=y, z=z) + + def test_shapeless_compile(self): + y = 1 + + @partial(mx.compile, shapeless=True) + def fun(x): + return x + y + + x = mx.array([1, 2]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) + + # The function is not recompiled, so the change + # to y should not be reflected in the output + y = 2 + x = mx.array([1, 2, 3]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) + + # Type change recompiles + x = mx.array([1.0, 2.0, 3.0]) + self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) + + # Dim change recompiles + x = mx.array([[1, 2, 3]]) + self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]]))) + + def test_shapeless_compile_with_broadcasts(self): + x = mx.ones((2, 2)) + y = mx.array([2, 2]) + + def fun(x, y): + return x * y + + cfun = mx.compile(fun, shapeless=True) + self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y))) + self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x))) + y = mx.array([[3]]) + self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y))) + self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x))) + + def test_shapeless_compile_with_reduction(self): + # Test shapeless compile with a reduction + z = 1 + + @partial(mx.compile, shapeless=True) + def fun(x, y): + return x + y.sum(0, keepdims=True) + z + + x = mx.ones((2, 2), mx.int32) + y = mx.ones((2, 2), mx.int32) + 
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4))) + x = mx.ones((3, 3), mx.int32) + y = mx.ones((3, 3), mx.int32) + z = 2 + self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5))) + + x1 = mx.array([[1, 2], [3, 4], [5, 6]]) + x2 = mx.array([[1, 2]]) + + def fun(x): + return x * x.sum(-1, keepdims=True) + + cfun = mx.compile(fun, shapeless=True) + mx.eval(cfun(x1)) + self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) + + def test_compile_with_constant(self): + + # Test float + @partial(mx.compile) + def fun(x, y): + return x + y + + z = fun(mx.array(1.0), 1.0) + self.assertEqual(z.item(), 2.0) + + z = fun(mx.array(1.0), 2.0) + self.assertEqual(z.item(), 3.0) + + z = fun(mx.array(1.0), y=1.0) + self.assertEqual(z.item(), 2.0) + + z = fun(mx.array(1.0), y=3.0) + self.assertEqual(z.item(), 4.0) + + # Test tuple + @partial(mx.compile) + def fun(x, y=(1, 2)): + return x + y[0] + y[1] + + z = fun(mx.array(1)) + self.assertEqual(z.item(), 4) + + z = fun(mx.array(1), (2, 2)) + self.assertEqual(z.item(), 5) + + z = fun(mx.array(1), (2, 1)) + self.assertEqual(z.item(), 4) + + # Test bool + @partial(mx.compile) + def fun(x, y): + if y: + return x + 1 + else: + return x + 2 + + z = fun(mx.array(1), True) + self.assertEqual(z.item(), 2) + + z = fun(mx.array(1), False) + self.assertEqual(z.item(), 3) + + # Test string + @partial(mx.compile) + def fun(x, y): + if y == "one": + return x + 1 + else: + return x + 2 + + z = fun(mx.array(1), "one") + self.assertEqual(z.item(), 2) + + z = fun(mx.array(1), "two") + self.assertEqual(z.item(), 3) + if __name__ == "__main__": unittest.main() diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 8ad67a1ed..569ab0913 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -624,31 +624,23 @@ TEST_CASE("test transform compiled function") { CHECK(!outs[0].inputs()[1].has_primitive()); } -TEST_CASE("test metal fusion kernel reuse") { - if (default_device() != Device::gpu) { - return; - } - +TEST_CASE("test fusion kernel reuse") { auto cfun = compile(gelu_1); auto x = array({2.0f, -2.0f}); auto y = cfun({x})[0]; auto p = std::dynamic_pointer_cast(y.primitive_ptr()); eval(y); - std::string lib_name = p->metal_lib_name(); - std::string lib_source = p->metal_lib_source(); + std::string lib_name = p->lib_name(); CHECK(!lib_name.empty()); - CHECK(!lib_source.empty()); x = astype(reshape(arange(10), {2, 5}), float32); auto z = cfun({x})[0]; auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); eval(z); - std::string lib_name_z = pz->metal_lib_name(); - std::string lib_source_z = pz->metal_lib_source(); + std::string lib_name_z = pz->lib_name(); CHECK(!lib_name_z.empty()); - CHECK(lib_source_z.empty()); CHECK_EQ(lib_name, lib_name_z); } @@ -657,29 +649,57 @@ auto add3(const std::vector& xs) { return std::vector{xs[0] + xs[0] + xs[0]}; } -TEST_CASE("test metal fusion types") { - if (default_device() != Device::gpu) { - return; - } - +TEST_CASE("test fusion types") { auto cfun = compile(add3); auto x = array({2.0f, -2.0f}); auto y = cfun({x})[0]; auto p = std::dynamic_pointer_cast(y.primitive_ptr()); eval(y); - std::string lib_name = p->metal_lib_name(); - std::string lib_source = p->metal_lib_source(); + std::string lib_name = p->lib_name(); CHECK(!lib_name.empty()); - CHECK(!lib_source.empty()); x = array({2, -2}, int32); auto z = cfun({x})[0]; auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); eval(z); - std::string lib_name_z = pz->metal_lib_name(); - std::string lib_source_z = 
pz->metal_lib_source(); + std::string lib_name_z = pz->lib_name(); CHECK(!lib_name_z.empty()); - CHECK(!lib_source_z.empty()); +} + +auto compile_shapeless_not_ok(const std::vector<array>& inputs) { + auto x = reshape(inputs[0], {2, 2}); + return std::vector<array>{x}; +} + +auto compile_shapeless_ok(const std::vector<array>& inputs) { + auto x = inputs[0] + array({2}); + return std::vector<array>{x}; +} + +TEST_CASE("test shapeless compile") { + { + auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true); + CHECK_THROWS(cfun({array({1, 2, 3, 4})})); + } + + { + auto cfun = compile(compile_shapeless_ok, /* shapeless */ true); + auto out = cfun({array({1, 2})})[0]; + auto out2 = cfun({array({1, 2, 3, 4})})[0]; + + // Not making a new constant array since no recompile, + // hence the ids should be the same + CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id()); + CHECK(array_equal(out2, array({3, 4, 5, 6})).item<bool>()); + + // Recompile since type changes + out2 = cfun({array({1.0, 2.0})})[0]; + CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); + + // Recompile since ndim changes + out2 = cfun({array({1.0, 2.0}, {1, 2})})[0]; + CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); + } }
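A minimal usage sketch of the ``shapeless`` flag this patch introduces, modeled on benchmarks/python/compile_bench.py and python/tests/test_compile.py (the gelu helper is the same one used in the benchmark; the shapes are illustrative only):

import math

import mlx.core as mx


def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


cgelu = mx.compile(gelu, shapeless=True)

# The first call traces and compiles the graph.
y1 = cgelu(mx.random.uniform(shape=(1000, 1024)))

# A different shape with the same ndim and dtype reuses the compiled graph;
# a change in the number of dimensions or the dtype still triggers a recompile.
y2 = cgelu(mx.random.uniform(shape=(512, 1024)))
mx.eval(y1, y2)

Functions whose tape contains primitives outside allows_shapeless (for example an explicit reshape, as in compile_shapeless_not_ok above) throw when compiled with ``shapeless`` enabled, per the docstring added in python/src/transforms.cpp.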
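The patch also routes keyword arguments through compiled functions and hashes non-array scalar arguments (floats, ints, bools, strings, and hashable tuples, per value_hash in PyCompiledFun) into the cache key, so changing such a constant retraces instead of silently reusing a stale graph. A small sketch based on test_compile_with_constant and test_compile_kwargs; the scale_shift function is hypothetical, for illustration only:

import mlx.core as mx


@mx.compile
def scale_shift(x, scale, bias=0.0):
    return scale * x + bias


x = mx.array([1.0, 2.0])
print(scale_shift(x, 2.0))            # traced with scale=2.0, bias=0.0
print(scale_shift(x, 2.0, bias=1.0))  # new bias -> separate cache entry
print(scale_shift(x, 3.0, bias=1.0))  # new scale -> separate cache entry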