Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun 2024-02-19 21:43:54 -08:00 committed by GitHub
parent d0fda82595
commit 5798256fcf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 645 additions and 113 deletions

View File

@ -0,0 +1,109 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import math
import random
import mlx.core as mx
from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
x = mx.random.uniform(shape=(1000, 1024))
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(gelu), x, msg="fixed gelu")
time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x, y):
x = x[: randint()]
for _ in range(10):
x = fun(x)
y = fun(y)
return x, y
return bench_fun
y = mx.random.uniform(shape=(1000, 1024))
time_fn(gen_fun(gelu), x, y, msg="variable gelu")
time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
time_fn(
gen_fun(mx.compile(gelu, shapeless=True)),
x,
y,
msg="shapeless variable gelu",
)
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)
def layernorm(x):
x = x.astype(mx.float32)
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
x = (x - means) * mx.rsqrt(var + 1e-4)
x = x.astype(mx.float16)
return weight * x + bias
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x):
x = x[: randint()]
for _ in range(10):
x = fun(x)
return x
return bench_fun
random.seed(0)
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
random.seed(0)
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
random.seed(0)
time_fn(
gen_fun(mx.compile(layernorm, shapeless=True)),
x,
msg="shapeless variable layernorm",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Compile benchmarks.")
args = parser.parse_args()
bench_gelu()
bench_layernorm()

View File

@ -6,7 +6,11 @@ import mlx.core as mx
def time_fn(fn, *args, **kwargs): def time_fn(fn, *args, **kwargs):
print(f"Timing {fn.__name__} ...", end=" ") msg = kwargs.pop("msg", None)
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup # warmup
for _ in range(5): for _ in range(5):

View File

@ -37,7 +37,7 @@ std::string build_lib_name(
os << "C"; os << "C";
print_constant(constant_hasher, x); print_constant(constant_hasher, x);
} else { } else {
os << ((x.size() == 1) ? "S" : "V"); os << (is_scalar(x) ? "S" : "V");
} }
} }
os << "_"; os << "_";
@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) {
} }
} }
inline bool is_scalar(const array& x) {
return x.size() == 1;
};
// Return a pointer to a compiled function // Return a pointer to a compiled function
void* compile( void* compile(
const std::string& kernel_name, const std::string& kernel_name,
@ -358,7 +354,7 @@ void Compiled::eval_cpu(
bool all_col_contig = true; bool all_col_contig = true;
int non_scalar_inputs = 0; int non_scalar_inputs = 0;
for (auto& x : inputs) { for (auto& x : inputs) {
if (x.size() == 1) { if (is_scalar(x)) {
continue; continue;
} }
non_scalar_inputs++; non_scalar_inputs++;
@ -385,7 +381,7 @@ void Compiled::eval_cpu(
auto& x = inputs[i]; auto& x = inputs[i];
args.push_back((void*)x.data<void>()); args.push_back((void*)x.data<void>());
if (contiguous || x.size() <= 1) { if (contiguous || is_scalar(x)) {
continue; continue;
} }
@ -458,7 +454,7 @@ void Compiled::eval_cpu(
// - Donatable // - Donatable
// - Correct size // - Correct size
// - Not a constant // - Not a constant
if (in.flags().contiguous && in.size() > 1 && in.is_donatable() && if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in); outputs[o++].copy_shared_buffer(in);
} }

View File

@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) {
void print_constant(std::ostream& os, const array& x); void print_constant(std::ostream& os, const array& x);
inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -31,9 +31,6 @@ inline void build_kernel(
return constant_ids.find(x.id()) != constant_ids.end(); return constant_ids.find(x.id()) != constant_ids.end();
}; };
// For scalar we shouldn't do the indexing things, just read at 0
auto is_scalar = [](const array& x) { return x.size() == 1; };
NodeNamer namer; NodeNamer namer;
bool add_indices = false; bool add_indices = false;
int cnt = 0; int cnt = 0;
@ -226,8 +223,7 @@ void Compiled::eval_gpu(
/* ndim = */ 0, /* ndim = */ 0,
/* dynamic_dims = */ true); /* dynamic_dims = */ true);
kernel_source_ = kernel.str(); lib = d.get_library(kernel_lib_, kernel.str());
lib = d.get_library(kernel_lib_, kernel_source_);
} }
// Figure out which kernel we are using // Figure out which kernel we are using
@ -235,7 +231,7 @@ void Compiled::eval_gpu(
bool contiguous = true; bool contiguous = true;
for (auto& x : inputs) { for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) && if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
x.size() > 1) { !is_scalar(x)) {
contiguous = false; contiguous = false;
break; break;
} }
@ -256,7 +252,7 @@ void Compiled::eval_gpu(
auto& x = inputs[i]; auto& x = inputs[i];
// Skip scalar inputs. // Skip scalar inputs.
if (x.size() <= 1) { if (is_scalar(x)) {
continue; continue;
} }
@ -311,7 +307,7 @@ void Compiled::eval_gpu(
} }
auto& x = inputs[i]; auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++); set_array_buffer(compute_encoder, x, cnt++);
if (!contiguous && x.size() > 1) { if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes( compute_encoder->setBytes(
strides[stride_idx].data(), strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t), strides[stride_idx].size() * sizeof(size_t),

View File

@ -13,7 +13,7 @@
namespace mlx::core { namespace mlx::core {
constexpr int max_compile_depth = 10; constexpr int max_compile_depth = 11;
bool is_unary(const Primitive& p) { bool is_unary(const Primitive& p) {
return ( return (
@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) {
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient); return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
} }
bool is_reduction(const Primitive& p) {
return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);
}
bool is_fusable(const Primitive& p) { bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p); return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
} }
namespace detail { bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
std::vector<array> compile_replace( is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
const std::vector<array>& tape, typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
const std::vector<array>& trace_inputs, typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
const std::vector<array>& trace_outputs, }
const std::vector<array>& inputs);
} // namespace detail
Compiled::Compiled( Compiled::Compiled(
Stream stream, Stream stream,
@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) {
} }
} }
std::vector<std::vector<int>> Compiled::output_shapes(
const std::vector<array>& inputs) {
size_t nd = 0;
for (auto& in : inputs) {
nd = std::max(nd, in.ndim());
}
std::vector<int> out_shape(nd, 0);
for (auto& in : inputs) {
auto dd = nd - in.ndim();
for (auto i = dd; i < nd; ++i) {
out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
}
}
// All outputs have the same shape
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
}
namespace detail { namespace detail {
CompileMode& compile_mode() { CompileMode& compile_mode() {
@ -180,21 +198,30 @@ struct CompilerCache {
std::vector<array> outputs; std::vector<array> outputs;
std::vector<array> tape; std::vector<array> tape;
bool empty{true}; bool empty{true};
std::vector<uint64_t> constants;
}; };
// Returns a reference to a CacheEntry which can be updated // Returns a reference to a CacheEntry which can be updated
// by the caller to avoid copying large tapes / inputs / outputs // by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) { CacheEntry& find(
size_t fun_id,
const std::vector<array>& inputs,
bool shapeless,
const std::vector<uint64_t>& constants) {
// Try to find the entry // Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}}); auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second; auto& entries = entry_it->second;
auto is_match = [](const std::vector<array>& in1, auto is_match = [shapeless](
const std::vector<array>& in2) { const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) { if (in1.size() != in2.size()) {
return false; return false;
} }
for (int i = 0; i < in1.size(); ++i) { for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) { if (in1[i].ndim() != in2[i].ndim()) {
return false;
}
if (!shapeless && in1[i].shape() != in2[i].shape()) {
return false; return false;
} }
if (in1[i].dtype() != in2[i].dtype()) { if (in1[i].dtype() != in2[i].dtype()) {
@ -210,7 +237,7 @@ struct CompilerCache {
// more easily searchable structure. // more easily searchable structure.
for (auto& entry : entries) { for (auto& entry : entries) {
// Check the inputs match and return if so // Check the inputs match and return if so
if (is_match(inputs, entry.inputs)) { if (is_match(inputs, entry.inputs) && constants == entry.constants) {
return entry; return entry;
} }
} }
@ -651,7 +678,8 @@ std::vector<array> compile_replace(
const std::vector<array>& tape, const std::vector<array>& tape,
const std::vector<array>& trace_inputs, const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs, const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) { const std::vector<array>& inputs,
bool shapeless) {
std::unordered_map<uintptr_t, array> trace_to_real; std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) { for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
@ -669,18 +697,29 @@ std::vector<array> compile_replace(
real_inputs.push_back(trace_to_real.at(in.id())); real_inputs.push_back(trace_to_real.at(in.id()));
} }
if (a.siblings().empty()) { if (a.siblings().empty()) {
auto shape =
shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape();
auto real_a = array( auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); std::move(shape),
a.dtype(),
a.primitive_ptr(),
std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)}); trace_to_real.insert({a.id(), std::move(real_a)});
} else { } else {
// Ensure the order is correct for multi-output primitives // Ensure the order is correct for multi-output primitives
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types; std::vector<Dtype> types;
auto trace_out = a.outputs(); auto trace_out = a.outputs();
for (auto& o : trace_out) { for (auto& o : trace_out) {
shapes.push_back(o.shape());
types.push_back(o.dtype()); types.push_back(o.dtype());
} }
std::vector<std::vector<int>> shapes;
if (shapeless) {
shapes = a.primitive().output_shapes(real_inputs);
} else {
for (auto& o : trace_out) {
shapes.push_back(o.shape());
}
}
auto real_out = auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs); array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) { for (int i = 0; i < trace_out.size(); ++i) {
@ -697,13 +736,34 @@ std::vector<array> compile_replace(
return outputs; return outputs;
} }
void compile_validate_shapeless(const std::vector<array>& tape) {
for (auto& t : tape) {
if (!t.has_primitive()) {
continue;
}
auto& p = t.primitive();
if (allows_shapeless(p)) {
continue;
}
std::ostringstream msg;
msg << "[compile] Cannot compile primitive ";
p.print(msg);
msg << " with shapeless enabled.";
throw std::invalid_argument(msg.str());
}
}
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) { size_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (compile_mode() == CompileMode::disabled) { if (compile_mode() == CompileMode::disabled) {
return fun; return fun;
} }
return [fun, fun_id](const std::vector<array>& inputs) { return [fun, fun_id, shapeless, constants = std::move(constants)](
const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph // If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer(); return in.is_tracer();
@ -712,12 +772,14 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
} }
// Find a cache entry with the correct inputs // Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs); auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);
// No matching cache entry existed, so compile // No matching cache entry existed, so compile
if (entry.empty) { if (entry.empty) {
// Mark the entry as not empty since we are about to fill it // Mark the entry as not empty since we are about to fill it
entry.empty = false; entry.empty = false;
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph // Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
@ -739,11 +801,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
if (compile_mode() != CompileMode::no_fuse) { if (compile_mode() != CompileMode::no_fuse) {
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs); compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
} }
if (shapeless) {
compile_validate_shapeless(entry.tape);
}
} }
// At this point we must have a tape, now replace the placeholders // At this point we must have a tape, now replace the placeholders
// with real arrays that can be evaluated // with real arrays that can be evaluated
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs); return compile_replace(
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
}; };
} }
@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) {
} // namespace detail } // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) { const std::function<std::vector<array>(const std::vector<array>&)>& fun,
bool shapeless /* false */) {
if (detail::compile_mode() == CompileMode::disabled) { if (detail::compile_mode() == CompileMode::disabled) {
return fun; return fun;
} }
auto fun_id = detail::getAddress(fun); auto fun_id = detail::getAddress(fun);
return detail::compile(fun, fun_id); return detail::compile(fun, fun_id, shapeless);
} }
void disable_compile() { void disable_compile() {

View File

@ -8,9 +8,10 @@ namespace mlx::core {
enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
// Compile takes a function and returns a new function /** Compile takes a function and returns a compiled function. */
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun); const std::function<std::vector<array>(const std::vector<array>&)>& fun,
bool shapeless = false);
/** Globally disable compilation. /** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also

View File

@ -71,6 +71,15 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
throw std::invalid_argument("Primitive's vmap not implemented."); throw std::invalid_argument("Primitive's vmap not implemented.");
}; };
std::vector<std::vector<int>> Primitive::output_shapes(
const std::vector<array>&) {
std::ostringstream msg;
msg << "[Primitive::output_shapes] ";
this->print(msg);
msg << " cannot infer output shapes.";
throw std::invalid_argument(msg.str());
};
std::vector<array> Abs::vjp( std::vector<array> Abs::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const std::vector<array>& cotangents,
@ -383,6 +392,13 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
} }
std::vector<std::vector<int>> ArgReduce::output_shapes(
const std::vector<array>& inputs) {
auto out_shape = inputs[0].shape();
out_shape[axis_] = 1;
return {out_shape};
}
bool ArgSort::is_equivalent(const Primitive& other) const { bool ArgSort::is_equivalent(const Primitive& other) const {
const ArgSort& r_other = static_cast<const ArgSort&>(other); const ArgSort& r_other = static_cast<const ArgSort&>(other);
return axis_ == r_other.axis_; return axis_ == r_other.axis_;
@ -2202,6 +2218,15 @@ bool Reduce::is_equivalent(const Primitive& other) const {
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_; return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
} }
std::vector<std::vector<int>> Reduce::output_shapes(
const std::vector<array>& inputs) {
std::vector<int> out_shape = inputs[0].shape();
for (auto i : axes_) {
out_shape[i] = 1;
}
return {out_shape};
}
std::vector<array> Round::vjp( std::vector<array> Round::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const std::vector<array>& cotangents,

View File

@ -36,6 +36,12 @@
return true; \ return true; \
} }
#define DEFINE_INPUT_OUTPUT_SHAPE() \
std::vector<std::vector<int>> output_shapes( \
const std::vector<array>& inputs) override { \
return {inputs[0].shape()}; \
};
namespace mlx::core { namespace mlx::core {
// Abstract base class // Abstract base class
@ -102,6 +108,11 @@ class Primitive {
return false; return false;
} }
/** Get the output shapes of the primitive. This is not required to be
* implemented by derived classes, in which case it will throw. */
virtual std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs);
virtual ~Primitive() = default; virtual ~Primitive() = default;
Primitive(const Primitive& other) = delete; Primitive(const Primitive& other) = delete;
Primitive(Primitive&& other) = delete; Primitive(Primitive&& other) = delete;
@ -152,6 +163,7 @@ class Abs : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Abs) DEFINE_PRINT(Abs)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -168,6 +180,7 @@ class Add : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Add) DEFINE_PRINT(Add)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -226,6 +239,7 @@ class ArcCos : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(ArcCos) DEFINE_PRINT(ArcCos)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -242,6 +256,7 @@ class ArcCosh : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(ArcCosh) DEFINE_PRINT(ArcCosh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -258,6 +273,7 @@ class ArcSin : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(ArcSin) DEFINE_PRINT(ArcSin)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -274,6 +290,7 @@ class ArcSinh : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(ArcSinh) DEFINE_PRINT(ArcSinh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -290,6 +307,7 @@ class ArcTan : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(ArcTan) DEFINE_PRINT(ArcTan)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -306,6 +324,7 @@ class ArcTanh : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(ArcTanh) DEFINE_PRINT(ArcTanh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -321,6 +340,7 @@ class ArgPartition : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_PRINT(ArgPartition) DEFINE_PRINT(ArgPartition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:
@ -346,6 +366,8 @@ class ArgReduce : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_PRINT(ArgReduce) DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
private: private:
ReduceType reduce_type_; ReduceType reduce_type_;
@ -364,6 +386,7 @@ class ArgSort : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_PRINT(ArgSort) DEFINE_PRINT(ArgSort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:
@ -383,6 +406,7 @@ class AsType : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(AsType) DEFINE_PRINT(AsType)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:
@ -448,6 +472,7 @@ class Ceil : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Ceil) DEFINE_PRINT(Ceil)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -478,15 +503,14 @@ class Compiled : public Primitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
void print(std::ostream& os) override; void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
std::string metal_lib_name() const { std::string lib_name() const {
return kernel_lib_; return kernel_lib_;
} }
std::string metal_lib_source() const {
return kernel_source_;
}
private: private:
const std::vector<array> inputs_; const std::vector<array> inputs_;
@ -495,7 +519,6 @@ class Compiled : public Primitive {
const std::unordered_set<uintptr_t> constant_ids_; const std::unordered_set<uintptr_t> constant_ids_;
std::string kernel_lib_; std::string kernel_lib_;
std::string kernel_source_;
}; };
class Concatenate : public UnaryPrimitive { class Concatenate : public UnaryPrimitive {
@ -563,6 +586,7 @@ class Copy : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Copy) DEFINE_PRINT(Copy)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -579,6 +603,7 @@ class Cos : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Cos) DEFINE_PRINT(Cos)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -595,6 +620,7 @@ class Cosh : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Cosh) DEFINE_PRINT(Cosh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -665,6 +691,7 @@ class Divide : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Divide) DEFINE_PRINT(Divide)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -683,6 +710,10 @@ class DivMod : public Primitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(DivMod) DEFINE_PRINT(DivMod)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
return std::vector{inputs[0].shape(), inputs[0].shape()};
};
private: private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs); void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
@ -699,6 +730,7 @@ class Remainder : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Remainder) DEFINE_PRINT(Remainder)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -715,6 +747,7 @@ class Equal : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
void print(std::ostream& os) override { void print(std::ostream& os) override {
if (equal_nan_) { if (equal_nan_) {
@ -740,6 +773,7 @@ class Erf : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Erf) DEFINE_PRINT(Erf)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -756,6 +790,7 @@ class ErfInv : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(ErfInv) DEFINE_PRINT(ErfInv)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -772,6 +807,7 @@ class Exp : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Exp) DEFINE_PRINT(Exp)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -814,6 +850,7 @@ class Floor : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Floor) DEFINE_PRINT(Floor)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -868,6 +905,7 @@ class Greater : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Greater) DEFINE_PRINT(Greater)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -884,6 +922,7 @@ class GreaterEqual : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(GreaterEqual) DEFINE_PRINT(GreaterEqual)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -900,6 +939,7 @@ class Less : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Less) DEFINE_PRINT(Less)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -916,6 +956,7 @@ class LessEqual : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(LessEqual) DEFINE_PRINT(LessEqual)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -958,6 +999,7 @@ class Log : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
void print(std::ostream& os) override { void print(std::ostream& os) override {
switch (base_) { switch (base_) {
@ -988,6 +1030,7 @@ class Log1p : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Log1p) DEFINE_PRINT(Log1p)
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1004,6 +1047,7 @@ class LogicalNot : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(LogicalNot) DEFINE_PRINT(LogicalNot)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1020,6 +1064,7 @@ class LogicalAnd : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(LogicalAnd) DEFINE_PRINT(LogicalAnd)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1036,6 +1081,7 @@ class LogicalOr : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(LogicalOr) DEFINE_PRINT(LogicalOr)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1052,6 +1098,7 @@ class LogAddExp : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(LogAddExp) DEFINE_PRINT(LogAddExp)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1085,6 +1132,7 @@ class Maximum : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Maximum) DEFINE_PRINT(Maximum)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1101,6 +1149,7 @@ class Minimum : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Minimum) DEFINE_PRINT(Minimum)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1117,6 +1166,7 @@ class Multiply : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Multiply) DEFINE_PRINT(Multiply)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1133,6 +1183,7 @@ class Negative : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Negative) DEFINE_PRINT(Negative)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1149,6 +1200,7 @@ class NotEqual : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(NotEqual) DEFINE_PRINT(NotEqual)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1193,6 +1245,7 @@ class Partition : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Partition) DEFINE_PRINT(Partition)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:
@ -1213,6 +1266,7 @@ class Power : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Power) DEFINE_PRINT(Power)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1305,6 +1359,9 @@ class Reduce : public UnaryPrimitive {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override;
void print(std::ostream& os) override { void print(std::ostream& os) override {
switch (reduce_type_) { switch (reduce_type_) {
case And: case And:
@ -1347,6 +1404,7 @@ class Round : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Round) DEFINE_PRINT(Round)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1455,6 +1513,7 @@ class Sigmoid : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Sigmoid) DEFINE_PRINT(Sigmoid)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1471,6 +1530,7 @@ class Sign : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Sign) DEFINE_PRINT(Sign)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1487,6 +1547,7 @@ class Sin : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Sin) DEFINE_PRINT(Sin)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1503,6 +1564,7 @@ class Sinh : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Sinh) DEFINE_PRINT(Sinh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1547,6 +1609,7 @@ class Softmax : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Softmax) DEFINE_PRINT(Softmax)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1563,6 +1626,7 @@ class Sort : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Sort) DEFINE_PRINT(Sort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:
@ -1604,6 +1668,7 @@ class Square : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Square) DEFINE_PRINT(Square)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1619,6 +1684,7 @@ class Sqrt : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override { void print(std::ostream& os) override {
@ -1644,6 +1710,7 @@ class StopGradient : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_PRINT(StopGradient) DEFINE_PRINT(StopGradient)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1660,6 +1727,7 @@ class Subtract : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Subtract) DEFINE_PRINT(Subtract)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1676,6 +1744,7 @@ class Tan : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Tan) DEFINE_PRINT(Tan)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
@ -1692,6 +1761,7 @@ class Tanh : public UnaryPrimitive {
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(Tanh) DEFINE_PRINT(Tanh)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);

View File

@ -18,7 +18,9 @@ std::vector<array> vmap_replace(
// idea. // idea.
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id); size_t fun_id,
bool shapeless = false,
std::vector<uint64_t> constants = {});
// Erase cached compile functions // Erase cached compile functions
void compile_erase(size_t fun_id); void compile_erase(size_t fun_id);

View File

@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import math import math
from functools import partial
from typing import Any from typing import Any
import mlx.core as mx import mlx.core as mx
@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module
def _make_activation_module(f): def _make_activation_module(f):
def decorator(klass): def decorator(klass):
klass.__doc__ = f.__doc__ klass.__call__ = lambda _, x: f(x)
klass.__call__ = lambda self, x: f(x)
return klass return klass
return decorator return decorator
@partial(mx.compile, shapeless=True)
def sigmoid(x): def sigmoid(x):
r"""Applies the element-wise function: r"""Applies the element-wise function:
@ -25,6 +26,7 @@ def sigmoid(x):
return mx.sigmoid(x) return mx.sigmoid(x)
@partial(mx.compile, shapeless=True)
def relu(x): def relu(x):
r"""Applies the Rectified Linear Unit. r"""Applies the Rectified Linear Unit.
@ -33,6 +35,7 @@ def relu(x):
return mx.maximum(x, 0) return mx.maximum(x, 0)
@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01): def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit. r"""Applies the Leaky Rectified Linear Unit.
@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01):
return mx.maximum(negative_slope * x, x) return mx.maximum(negative_slope * x, x)
@partial(mx.compile, shapeless=True)
def log_softmax(x, axis=-1): def log_softmax(x, axis=-1):
r"""Applies the Log Softmax function. r"""Applies the Log Softmax function.
@ -49,6 +53,7 @@ def log_softmax(x, axis=-1):
return x - mx.logsumexp(x, axis=axis, keepdims=True) return x - mx.logsumexp(x, axis=axis, keepdims=True)
@partial(mx.compile, shapeless=True)
def elu(x, alpha=1.0): def elu(x, alpha=1.0):
r"""Applies the Exponential Linear Unit. r"""Applies the Exponential Linear Unit.
@ -57,6 +62,7 @@ def elu(x, alpha=1.0):
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1)) return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
@partial(mx.compile, shapeless=True)
def relu6(x): def relu6(x):
r"""Applies the Rectified Linear Unit 6. r"""Applies the Rectified Linear Unit 6.
@ -65,6 +71,7 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0) return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1): def softmax(x, axis=-1):
r"""Applies the Softmax function. r"""Applies the Softmax function.
@ -73,6 +80,7 @@ def softmax(x, axis=-1):
return mx.softmax(x, axis=axis) return mx.softmax(x, axis=axis)
@partial(mx.compile, shapeless=True)
def softplus(x): def softplus(x):
r"""Applies the Softplus function. r"""Applies the Softplus function.
@ -81,6 +89,7 @@ def softplus(x):
return mx.logaddexp(x, 0) return mx.logaddexp(x, 0)
@partial(mx.compile, shapeless=True)
def softsign(x): def softsign(x):
r"""Applies the Softsign function. r"""Applies the Softsign function.
@ -89,6 +98,7 @@ def softsign(x):
return mx.divide(x, 1 + mx.abs(x)) return mx.divide(x, 1 + mx.abs(x))
@partial(mx.compile, shapeless=True)
def softshrink(x, lambd: float = 0.5): def softshrink(x, lambd: float = 0.5):
r"""Applies the Softshrink activation function. r"""Applies the Softshrink activation function.
@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5):
return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0) return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
@partial(mx.compile, shapeless=True)
def celu(x, alpha=1.0): def celu(x, alpha=1.0):
r"""Applies the Continuously Differentiable Exponential Linear Unit. r"""Applies the Continuously Differentiable Exponential Linear Unit.
@ -111,6 +122,7 @@ def celu(x, alpha=1.0):
return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1) return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
@partial(mx.compile, shapeless=True)
def silu(x): def silu(x):
r"""Applies the Sigmoid Linear Unit. Also known as Swish. r"""Applies the Sigmoid Linear Unit. Also known as Swish.
@ -120,6 +132,7 @@ def silu(x):
return x * mx.sigmoid(x) return x * mx.sigmoid(x)
@partial(mx.compile, shapeless=True)
def log_sigmoid(x): def log_sigmoid(x):
r"""Applies the Log Sigmoid function. r"""Applies the Log Sigmoid function.
@ -128,6 +141,7 @@ def log_sigmoid(x):
return -softplus(-x) return -softplus(-x)
@partial(mx.compile, shapeless=True)
def gelu(x): def gelu(x):
r"""Applies the Gaussian Error Linear Units function. r"""Applies the Gaussian Error Linear Units function.
@ -142,6 +156,7 @@ def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2 return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@partial(mx.compile, shapeless=True)
def gelu_approx(x): def gelu_approx(x):
r"""An approximation to Gaussian Error Linear Unit. r"""An approximation to Gaussian Error Linear Unit.
@ -159,6 +174,7 @@ def gelu_approx(x):
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square())) return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
@partial(mx.compile, shapeless=True)
def gelu_fast_approx(x): def gelu_fast_approx(x):
r"""A fast approximation to Gaussian Error Linear Unit. r"""A fast approximation to Gaussian Error Linear Unit.
@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
return a * mx.sigmoid(b) return a * mx.sigmoid(b)
class GLU(Module): @partial(mx.compile, shapeless=True)
r"""Applies the gated linear unit function.
This function splits the ``axis`` dimension of the input into two halves
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
"""
def __init__(self, axis: int = -1):
super().__init__()
self.axis = axis
def __call__(self, x) -> Any:
return glu(x=x, axis=self.axis)
def step(x: mx.array, threshold: float = 0.0): def step(x: mx.array, threshold: float = 0.0):
r"""Applies the Step Activation Function. r"""Applies the Step Activation Function.
@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0):
return mx.where(x > threshold, 1, 0) return mx.where(x > threshold, 1, 0)
@partial(mx.compile, shapeless=True)
def selu(x): def selu(x):
r"""Applies the Scaled Exponential Linear Unit. r"""Applies the Scaled Exponential Linear Unit.
@ -248,6 +245,7 @@ def selu(x):
return elu(x, 1.67326) * 1.0507 return elu(x, 1.67326) * 1.0507
@partial(mx.compile, shapeless=True)
def prelu(x: mx.array, alpha: mx.array) -> mx.array: def prelu(x: mx.array, alpha: mx.array) -> mx.array:
r"""Applies the element-wise parametric ReLU. r"""Applies the element-wise parametric ReLU.
@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
return mx.maximum(0, x) + alpha * mx.minimum(0, x) return mx.maximum(0, x) + alpha * mx.minimum(0, x)
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array: def mish(x: mx.array) -> mx.array:
r"""Applies the Mish function, element-wise. r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function. Mish: A Self Regularized Non-Monotonic Neural Activation Function.
@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array:
return x * mx.tanh(softplus(x)) return x * mx.tanh(softplus(x))
@partial(mx.compile, shapeless=True)
def hardswish(x): def hardswish(x):
r"""Applies the hardswish function, element-wise. r"""Applies the hardswish function, element-wise.
@ -282,6 +282,35 @@ def hardswish(x):
return x * mx.minimum(max_x_3, 6) / 6 return x * mx.minimum(max_x_3, 6) / 6
def tanh(x):
"""Applies the hyperbolic tangent function.
Simply ``mx.tanh(x)``.
"""
return mx.tanh(x)
class GLU(Module):
r"""Applies the gated linear unit function.
This function splits the ``axis`` dimension of the input into two halves
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
"""
def __init__(self, axis: int = -1):
super().__init__()
self.axis = axis
def __call__(self, x) -> Any:
return glu(x=x, axis=self.axis)
@_make_activation_module(mx.sigmoid) @_make_activation_module(mx.sigmoid)
class Sigmoid(Module): class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise. r"""Applies the sigmoid function, element-wise.
@ -500,14 +529,6 @@ class GELU(Module):
return self._act(x) return self._act(x)
def tanh(x):
"""Applies the hyperbolic tangent function.
Simply ``mx.tanh(x)``.
"""
return mx.tanh(x)
@_make_activation_module(tanh) @_make_activation_module(tanh)
class Tanh(Module): class Tanh(Module):
r"""Applies the hyperbolic tangent function. r"""Applies the hyperbolic tangent function.

View File

@ -555,13 +555,19 @@ struct PyCompiledFun {
size_t fun_id; size_t fun_id;
py::object captured_inputs; py::object captured_inputs;
py::object captured_outputs; py::object captured_outputs;
bool shapeless;
size_t num_outputs{0}; size_t num_outputs{0};
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs) PyCompiledFun(
const py::function& fun,
py::object inputs,
py::object outputs,
bool shapeless)
: fun(fun), : fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())), fun_id(reinterpret_cast<size_t>(fun.ptr())),
captured_inputs(inputs), captured_inputs(inputs),
captured_outputs(outputs) {} captured_outputs(outputs),
shapeless(shapeless) {}
PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete;
@ -571,11 +577,15 @@ struct PyCompiledFun {
other.fun_id = 0; other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs); captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs); captured_outputs = std::move(other.captured_outputs);
shapeless = other.shapeless;
num_outputs = other.num_outputs; num_outputs = other.num_outputs;
}; };
py::object operator()(const py::args& args) { py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto compile_fun = [this, &args](const std::vector<array>& a) { auto inputs = tree_flatten(args, false);
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
const std::vector<array>& a) {
// Put tracers into captured inputs // Put tracers into captured inputs
std::vector<array> flat_in_captures; std::vector<array> flat_in_captures;
std::vector<array> trace_captures; std::vector<array> trace_captures;
@ -586,8 +596,10 @@ struct PyCompiledFun {
tree_fill(captured_inputs, trace_captures); tree_fill(captured_inputs, trace_captures);
} }
auto [outputs, py_outputs] = tree_flatten_with_structure( auto tree_outputs =
std::move(fun(*tree_unflatten(args, a))), false); fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));
auto [outputs, py_outputs] =
tree_flatten_with_structure(std::move(tree_outputs), false);
tree_cache().insert({fun_id, py_outputs}); tree_cache().insert({fun_id, py_outputs});
@ -607,7 +619,14 @@ struct PyCompiledFun {
return outputs; return outputs;
}; };
auto inputs = tree_flatten(args, false); {
auto flat_kwargs = tree_flatten(kwargs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_kwargs.begin()),
std::make_move_iterator(flat_kwargs.end()));
}
if (!py::isinstance<py::none>(captured_inputs)) { if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false); auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert( inputs.insert(
@ -616,8 +635,39 @@ struct PyCompiledFun {
std::make_move_iterator(flat_in_captures.end())); std::make_move_iterator(flat_in_captures.end()));
} }
// Collect the compilation constants
std::vector<uint64_t> constants;
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
// Consider expanding tuples to their contents including start and end
// ids
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
auto r = py::hash(o);
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::int_>(o)) {
auto r = o.cast<int64_t>();
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::float_>(o)) {
auto r = o.cast<double>();
return *reinterpret_cast<uint64_t*>(&r);
} else {
return std::nullopt;
}
};
for (int i = 0; i < args.size(); i++) {
if (auto h = value_hash(args[i]); h.has_value()) {
constants.push_back(*h);
}
}
for (auto& pair : kwargs) {
if (auto h = value_hash(pair.second); h.has_value()) {
constants.push_back(*value_hash(pair.first));
constants.push_back(*h);
}
}
// Compile and call // Compile and call
auto outputs = detail::compile(compile_fun, fun_id)(inputs); auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) { if (!py::isinstance<py::none>(captured_outputs)) {
std::vector<array> captures( std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs), std::make_move_iterator(outputs.begin() + num_outputs),
@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) {
"compile", "compile",
[](const py::function& fun, [](const py::function& fun,
const py::object& inputs, const py::object& inputs,
const py::object& outputs) { const py::object& outputs,
return py::cpp_function(PyCompiledFun{fun, inputs, outputs}); bool shapeless) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
}, },
"fun"_a, "fun"_a,
"inputs"_a = std::nullopt, "inputs"_a = std::nullopt,
"outputs"_a = std::nullopt, "outputs"_a = std::nullopt,
"shapeless"_a = false,
R"pbdoc( R"pbdoc(
compile(fun: function) -> function compile(fun: function) -> function
@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) {
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists, :obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
Default: ``None`` Default: ``None``
shapeless (bool, optional): A function compiled with the ``shapeless``
option enabled will not be recompiled when the input shape changes. Not all
functions can be compiled with ``shapeless`` enabled. Attempting to compile
such functions with shapeless enabled will throw. Note, changing the number
of dimensions or type of any input will result in a recompilation even with
``shapeless`` set to ``True``. Default: ``False``
Returns: Returns:
function: A compiled function which has the same input arguments function: A compiled function which has the same input arguments

View File

@ -381,6 +381,164 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
def test_compile_kwargs(self):
@mx.compile
def fun(x, y, z):
return x + y + z
x = mx.array(1)
y = mx.array(2)
z = mx.array(3)
out = fun(x, y=y, z=z)
self.assertEqual(out.item(), 6)
def test_shapeless_compile(self):
y = 1
@partial(mx.compile, shapeless=True)
def fun(x):
return x + y
x = mx.array([1, 2])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
# The function is not recompiled, so the change
# to y should not be reflected in the output
y = 2
x = mx.array([1, 2, 3])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
# Type change recompiles
x = mx.array([1.0, 2.0, 3.0])
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
fun(x, y=y, z=z)
def test_shapeless_compile(self):
y = 1
@partial(mx.compile, shapeless=True)
def fun(x):
return x + y
x = mx.array([1, 2])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
# The function is not recompiled, so the change
# to y should not be reflected in the output
y = 2
x = mx.array([1, 2, 3])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
# Type change recompiles
x = mx.array([1.0, 2.0, 3.0])
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
# Dim change recompiles
x = mx.array([[1, 2, 3]])
self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))
def test_shapeless_compile_with_broadcasts(self):
x = mx.ones((2, 2))
y = mx.array([2, 2])
def fun(x, y):
return x * y
cfun = mx.compile(fun, shapeless=True)
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
y = mx.array([[3]])
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
def test_shapeless_compile_with_reduction(self):
# Test shapeless compile with a reduction
z = 1
@partial(mx.compile, shapeless=True)
def fun(x, y):
return x + y.sum(0, keepdims=True) + z
x = mx.ones((2, 2), mx.int32)
y = mx.ones((2, 2), mx.int32)
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))
x = mx.ones((3, 3), mx.int32)
y = mx.ones((3, 3), mx.int32)
z = 2
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))
x1 = mx.array([[1, 2], [3, 4], [5, 6]])
x2 = mx.array([[1, 2]])
def fun(x):
return x * x.sum(-1, keepdims=True)
cfun = mx.compile(fun, shapeless=True)
mx.eval(cfun(x1))
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)
def fun(x, y):
return x + y
z = fun(mx.array(1.0), 1.0)
self.assertEqual(z.item(), 2.0)
z = fun(mx.array(1.0), 2.0)
self.assertEqual(z.item(), 3.0)
z = fun(mx.array(1.0), y=1.0)
self.assertEqual(z.item(), 2.0)
z = fun(mx.array(1.0), y=3.0)
self.assertEqual(z.item(), 4.0)
# Test tuple
@partial(mx.compile)
def fun(x, y=(1, 2)):
return x + y[0] + y[1]
z = fun(mx.array(1))
self.assertEqual(z.item(), 4)
z = fun(mx.array(1), (2, 2))
self.assertEqual(z.item(), 5)
z = fun(mx.array(1), (2, 1))
self.assertEqual(z.item(), 4)
# Test bool
@partial(mx.compile)
def fun(x, y):
if y:
return x + 1
else:
return x + 2
z = fun(mx.array(1), True)
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), False)
self.assertEqual(z.item(), 3)
# Test string
@partial(mx.compile)
def fun(x, y):
if y == "one":
return x + 1
else:
return x + 2
z = fun(mx.array(1), "one")
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -624,31 +624,23 @@ TEST_CASE("test transform compiled function") {
CHECK(!outs[0].inputs()[1].has_primitive()); CHECK(!outs[0].inputs()[1].has_primitive());
} }
TEST_CASE("test metal fusion kernel reuse") { TEST_CASE("test fusion kernel reuse") {
if (default_device() != Device::gpu) {
return;
}
auto cfun = compile(gelu_1); auto cfun = compile(gelu_1);
auto x = array({2.0f, -2.0f}); auto x = array({2.0f, -2.0f});
auto y = cfun({x})[0]; auto y = cfun({x})[0];
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr()); auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
eval(y); eval(y);
std::string lib_name = p->metal_lib_name(); std::string lib_name = p->lib_name();
std::string lib_source = p->metal_lib_source();
CHECK(!lib_name.empty()); CHECK(!lib_name.empty());
CHECK(!lib_source.empty());
x = astype(reshape(arange(10), {2, 5}), float32); x = astype(reshape(arange(10), {2, 5}), float32);
auto z = cfun({x})[0]; auto z = cfun({x})[0];
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr()); auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
eval(z); eval(z);
std::string lib_name_z = pz->metal_lib_name(); std::string lib_name_z = pz->lib_name();
std::string lib_source_z = pz->metal_lib_source();
CHECK(!lib_name_z.empty()); CHECK(!lib_name_z.empty());
CHECK(lib_source_z.empty());
CHECK_EQ(lib_name, lib_name_z); CHECK_EQ(lib_name, lib_name_z);
} }
@ -657,29 +649,57 @@ auto add3(const std::vector<array>& xs) {
return std::vector<array>{xs[0] + xs[0] + xs[0]}; return std::vector<array>{xs[0] + xs[0] + xs[0]};
} }
TEST_CASE("test metal fusion types") { TEST_CASE("test fusion types") {
if (default_device() != Device::gpu) {
return;
}
auto cfun = compile(add3); auto cfun = compile(add3);
auto x = array({2.0f, -2.0f}); auto x = array({2.0f, -2.0f});
auto y = cfun({x})[0]; auto y = cfun({x})[0];
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr()); auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
eval(y); eval(y);
std::string lib_name = p->metal_lib_name(); std::string lib_name = p->lib_name();
std::string lib_source = p->metal_lib_source();
CHECK(!lib_name.empty()); CHECK(!lib_name.empty());
CHECK(!lib_source.empty());
x = array({2, -2}, int32); x = array({2, -2}, int32);
auto z = cfun({x})[0]; auto z = cfun({x})[0];
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr()); auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
eval(z); eval(z);
std::string lib_name_z = pz->metal_lib_name(); std::string lib_name_z = pz->lib_name();
std::string lib_source_z = pz->metal_lib_source();
CHECK(!lib_name_z.empty()); CHECK(!lib_name_z.empty());
CHECK(!lib_source_z.empty()); }
auto compile_shapeless_not_ok(const std::vector<array>& inputs) {
auto x = reshape(inputs[0], {2, 2});
return std::vector<array>{x};
}
auto compile_shapeless_ok(const std::vector<array>& inputs) {
auto x = inputs[0] + array({2});
return std::vector<array>{x};
}
TEST_CASE("test shapeless compile") {
{
auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true);
CHECK_THROWS(cfun({array({1, 2, 3, 4})}));
}
{
auto cfun = compile(compile_shapeless_ok, /* shapeless */ true);
auto out = cfun({array({1, 2})})[0];
auto out2 = cfun({array({1, 2, 3, 4})})[0];
// Not making a new constant array since no recompile,
// hence the ids should be the same
CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id());
CHECK(array_equal(out2, array({3, 4, 5, 6})).item<bool>());
// Recompile since type changes
out2 = cfun({array({1.0, 2.0})})[0];
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
// Recompile since ndim changes
out2 = cfun({array({1.0, 2.0}, {1, 2})})[0];
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
}
} }