mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
d0fda82595
commit
5798256fcf
109
benchmarks/python/compile_bench.py
Normal file
109
benchmarks/python/compile_bench.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def bench_gelu():
    """Time an erf-based GELU with and without ``mx.compile``.

    Two scenarios are measured through ``time_fn``:

    * fixed shape: the same ``(1000, 1024)`` input is fed through ten
      chained GELU applications, eagerly and compiled;
    * variable shape: the first input is sliced to a random number of rows
      on every call, and the function is run eagerly, compiled, and
      compiled with ``shapeless=True``.
    """

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fun(fun):
        # Chain ten applications so each timed call does a meaningful
        # amount of work.
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fun(gelu), x, msg="fixed gelu")
    time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        # Random number of leading rows to keep, between 1 and x.shape[0].
        return random.randint(1, x.shape[0])

    def gen_fun(fun):
        # Variable-shape variant: ``x`` changes its leading dimension on
        # every call while ``y`` keeps its full shape.
        def bench_fun(x, y):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))

    # Reseed before every run so the three variants are timed on the same
    # sequence of random slice lengths (bench_layernorm does the same).
    random.seed(0)
    time_fn(gen_fun(gelu), x, y, msg="variable gelu")
    random.seed(0)
    time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    random.seed(0)
    time_fn(
        gen_fun(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def bench_layernorm():
    """Time a hand-written layer norm with and without ``mx.compile``.

    The layer norm upcasts its float16 input to float32, normalizes over the
    last axis, casts back to float16 and applies a float16 scale and shift.
    It is timed on a fixed ``(1000, 4096)`` input and on inputs whose number
    of rows is drawn at random on every call (eager, compiled, and compiled
    with ``shapeless=True``). The random generator is reseeded before each
    variable-shape run.
    """
    features = 4096
    weight = mx.random.uniform(shape=(features,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(features,)).astype(mx.float16)
    mx.eval(weight, bias)

    def layernorm(x):
        x32 = x.astype(mx.float32)
        mu = mx.mean(x32, axis=-1, keepdims=True)
        sigma2 = mx.var(x32, axis=-1, keepdims=True)
        normed = ((x32 - mu) * mx.rsqrt(sigma2 + 1e-4)).astype(mx.float16)
        return weight * normed + bias

    x = mx.random.uniform(shape=(1000, features)).astype(mx.float16)

    def repeat_ten(fun, value):
        # Feed the output back in as the next input, ten times over.
        for _ in range(10):
            value = fun(value)
        return value

    def make_fixed_bench(fun):
        def bench_fun(x):
            return repeat_ten(fun, x)

        return bench_fun

    time_fn(make_fixed_bench(layernorm), x, msg="fixed layernorm")
    time_fn(make_fixed_bench(mx.compile(layernorm)), x, msg="compiled fixed layernorm")

    total_rows = x.shape[0]

    def make_variable_bench(fun):
        def bench_fun(x):
            # Keep a random number of leading rows, between 1 and total_rows.
            rows = random.randint(1, total_rows)
            return repeat_ten(fun, x[:rows])

        return bench_fun

    random.seed(0)
    time_fn(make_variable_bench(layernorm), x, msg="variable layernorm")

    random.seed(0)
    time_fn(
        make_variable_bench(mx.compile(layernorm)),
        x,
        msg="compiled variable layernorm",
    )

    random.seed(0)
    time_fn(
        make_variable_bench(mx.compile(layernorm, shapeless=True)),
        x,
        msg="shapeless variable layernorm",
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # No options are defined; parsing the command line still provides
    # ``--help`` and rejects unexpected arguments.
    parser = argparse.ArgumentParser("Compile benchmarks.")
    args = parser.parse_args()

    # Run every benchmark in order.
    for run_benchmark in (bench_gelu, bench_layernorm):
        run_benchmark()
|
@ -6,7 +6,11 @@ import mlx.core as mx
|
|||||||
|
|
||||||
|
|
||||||
def time_fn(fn, *args, **kwargs):
|
def time_fn(fn, *args, **kwargs):
|
||||||
print(f"Timing {fn.__name__} ...", end=" ")
|
msg = kwargs.pop("msg", None)
|
||||||
|
if msg:
|
||||||
|
print(f"Timing {msg} ...", end=" ")
|
||||||
|
else:
|
||||||
|
print(f"Timing {fn.__name__} ...", end=" ")
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
|
@ -37,7 +37,7 @@ std::string build_lib_name(
|
|||||||
os << "C";
|
os << "C";
|
||||||
print_constant(constant_hasher, x);
|
print_constant(constant_hasher, x);
|
||||||
} else {
|
} else {
|
||||||
os << ((x.size() == 1) ? "S" : "V");
|
os << (is_scalar(x) ? "S" : "V");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
os << "_";
|
os << "_";
|
||||||
@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool is_scalar(const array& x) {
|
|
||||||
return x.size() == 1;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Return a pointer to a compiled function
|
// Return a pointer to a compiled function
|
||||||
void* compile(
|
void* compile(
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
@ -358,7 +354,7 @@ void Compiled::eval_cpu(
|
|||||||
bool all_col_contig = true;
|
bool all_col_contig = true;
|
||||||
int non_scalar_inputs = 0;
|
int non_scalar_inputs = 0;
|
||||||
for (auto& x : inputs) {
|
for (auto& x : inputs) {
|
||||||
if (x.size() == 1) {
|
if (is_scalar(x)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
non_scalar_inputs++;
|
non_scalar_inputs++;
|
||||||
@ -385,7 +381,7 @@ void Compiled::eval_cpu(
|
|||||||
auto& x = inputs[i];
|
auto& x = inputs[i];
|
||||||
args.push_back((void*)x.data<void>());
|
args.push_back((void*)x.data<void>());
|
||||||
|
|
||||||
if (contiguous || x.size() <= 1) {
|
if (contiguous || is_scalar(x)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -458,7 +454,7 @@ void Compiled::eval_cpu(
|
|||||||
// - Donatable
|
// - Donatable
|
||||||
// - Correct size
|
// - Correct size
|
||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.flags().contiguous && in.size() > 1 && in.is_donatable() &&
|
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||||
outputs[o++].copy_shared_buffer(in);
|
outputs[o++].copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
|
@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) {
|
|||||||
|
|
||||||
void print_constant(std::ostream& os, const array& x);
|
void print_constant(std::ostream& os, const array& x);
|
||||||
|
|
||||||
|
inline bool is_scalar(const array& x) {
|
||||||
|
return x.ndim() == 0;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -31,9 +31,6 @@ inline void build_kernel(
|
|||||||
return constant_ids.find(x.id()) != constant_ids.end();
|
return constant_ids.find(x.id()) != constant_ids.end();
|
||||||
};
|
};
|
||||||
|
|
||||||
// For scalar we shouldn't do the indexing things, just read at 0
|
|
||||||
auto is_scalar = [](const array& x) { return x.size() == 1; };
|
|
||||||
|
|
||||||
NodeNamer namer;
|
NodeNamer namer;
|
||||||
bool add_indices = false;
|
bool add_indices = false;
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
@ -226,8 +223,7 @@ void Compiled::eval_gpu(
|
|||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ true);
|
/* dynamic_dims = */ true);
|
||||||
|
|
||||||
kernel_source_ = kernel.str();
|
lib = d.get_library(kernel_lib_, kernel.str());
|
||||||
lib = d.get_library(kernel_lib_, kernel_source_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Figure out which kernel we are using
|
// Figure out which kernel we are using
|
||||||
@ -235,7 +231,7 @@ void Compiled::eval_gpu(
|
|||||||
bool contiguous = true;
|
bool contiguous = true;
|
||||||
for (auto& x : inputs) {
|
for (auto& x : inputs) {
|
||||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||||
x.size() > 1) {
|
!is_scalar(x)) {
|
||||||
contiguous = false;
|
contiguous = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -256,7 +252,7 @@ void Compiled::eval_gpu(
|
|||||||
auto& x = inputs[i];
|
auto& x = inputs[i];
|
||||||
|
|
||||||
// Skip scalar inputs.
|
// Skip scalar inputs.
|
||||||
if (x.size() <= 1) {
|
if (is_scalar(x)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -311,7 +307,7 @@ void Compiled::eval_gpu(
|
|||||||
}
|
}
|
||||||
auto& x = inputs[i];
|
auto& x = inputs[i];
|
||||||
set_array_buffer(compute_encoder, x, cnt++);
|
set_array_buffer(compute_encoder, x, cnt++);
|
||||||
if (!contiguous && x.size() > 1) {
|
if (!contiguous && !is_scalar(x)) {
|
||||||
compute_encoder->setBytes(
|
compute_encoder->setBytes(
|
||||||
strides[stride_idx].data(),
|
strides[stride_idx].data(),
|
||||||
strides[stride_idx].size() * sizeof(size_t),
|
strides[stride_idx].size() * sizeof(size_t),
|
||||||
|
118
mlx/compile.cpp
118
mlx/compile.cpp
@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
constexpr int max_compile_depth = 10;
|
constexpr int max_compile_depth = 11;
|
||||||
|
|
||||||
bool is_unary(const Primitive& p) {
|
bool is_unary(const Primitive& p) {
|
||||||
return (
|
return (
|
||||||
@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) {
|
|||||||
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
|
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_reduction(const Primitive& p) {
|
||||||
|
return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);
|
||||||
|
}
|
||||||
|
|
||||||
bool is_fusable(const Primitive& p) {
|
bool is_fusable(const Primitive& p) {
|
||||||
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
|
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace detail {
|
bool allows_shapeless(const Primitive& p) {
|
||||||
|
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
|
||||||
std::vector<array> compile_replace(
|
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
|
||||||
const std::vector<array>& tape,
|
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
|
||||||
const std::vector<array>& trace_inputs,
|
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
|
||||||
const std::vector<array>& trace_outputs,
|
}
|
||||||
const std::vector<array>& inputs);
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
Compiled::Compiled(
|
Compiled::Compiled(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> Compiled::output_shapes(
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
size_t nd = 0;
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
nd = std::max(nd, in.ndim());
|
||||||
|
}
|
||||||
|
std::vector<int> out_shape(nd, 0);
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
auto dd = nd - in.ndim();
|
||||||
|
for (auto i = dd; i < nd; ++i) {
|
||||||
|
out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// All outputs have the same shape
|
||||||
|
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
|
||||||
|
}
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
CompileMode& compile_mode() {
|
CompileMode& compile_mode() {
|
||||||
@ -180,21 +198,30 @@ struct CompilerCache {
|
|||||||
std::vector<array> outputs;
|
std::vector<array> outputs;
|
||||||
std::vector<array> tape;
|
std::vector<array> tape;
|
||||||
bool empty{true};
|
bool empty{true};
|
||||||
|
std::vector<uint64_t> constants;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns a reference to a CacheEntry which can be updated
|
// Returns a reference to a CacheEntry which can be updated
|
||||||
// by the caller to avoid copying large tapes / inputs / outputs
|
// by the caller to avoid copying large tapes / inputs / outputs
|
||||||
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
|
CacheEntry& find(
|
||||||
|
size_t fun_id,
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
bool shapeless,
|
||||||
|
const std::vector<uint64_t>& constants) {
|
||||||
// Try to find the entry
|
// Try to find the entry
|
||||||
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
||||||
auto& entries = entry_it->second;
|
auto& entries = entry_it->second;
|
||||||
auto is_match = [](const std::vector<array>& in1,
|
auto is_match = [shapeless](
|
||||||
const std::vector<array>& in2) {
|
const std::vector<array>& in1,
|
||||||
|
const std::vector<array>& in2) {
|
||||||
if (in1.size() != in2.size()) {
|
if (in1.size() != in2.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < in1.size(); ++i) {
|
for (int i = 0; i < in1.size(); ++i) {
|
||||||
if (in1[i].shape() != in2[i].shape()) {
|
if (in1[i].ndim() != in2[i].ndim()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!shapeless && in1[i].shape() != in2[i].shape()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (in1[i].dtype() != in2[i].dtype()) {
|
if (in1[i].dtype() != in2[i].dtype()) {
|
||||||
@ -210,7 +237,7 @@ struct CompilerCache {
|
|||||||
// more easily searchable structure.
|
// more easily searchable structure.
|
||||||
for (auto& entry : entries) {
|
for (auto& entry : entries) {
|
||||||
// Check the inputs match and return if so
|
// Check the inputs match and return if so
|
||||||
if (is_match(inputs, entry.inputs)) {
|
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
|
||||||
return entry;
|
return entry;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -651,7 +678,8 @@ std::vector<array> compile_replace(
|
|||||||
const std::vector<array>& tape,
|
const std::vector<array>& tape,
|
||||||
const std::vector<array>& trace_inputs,
|
const std::vector<array>& trace_inputs,
|
||||||
const std::vector<array>& trace_outputs,
|
const std::vector<array>& trace_outputs,
|
||||||
const std::vector<array>& inputs) {
|
const std::vector<array>& inputs,
|
||||||
|
bool shapeless) {
|
||||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||||
for (int i = 0; i < inputs.size(); ++i) {
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||||
@ -669,18 +697,29 @@ std::vector<array> compile_replace(
|
|||||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||||
}
|
}
|
||||||
if (a.siblings().empty()) {
|
if (a.siblings().empty()) {
|
||||||
|
auto shape =
|
||||||
|
shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape();
|
||||||
auto real_a = array(
|
auto real_a = array(
|
||||||
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
std::move(shape),
|
||||||
|
a.dtype(),
|
||||||
|
a.primitive_ptr(),
|
||||||
|
std::move(real_inputs));
|
||||||
trace_to_real.insert({a.id(), std::move(real_a)});
|
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||||
} else {
|
} else {
|
||||||
// Ensure the order is correct for multi-output primitives
|
// Ensure the order is correct for multi-output primitives
|
||||||
std::vector<std::vector<int>> shapes;
|
|
||||||
std::vector<Dtype> types;
|
std::vector<Dtype> types;
|
||||||
auto trace_out = a.outputs();
|
auto trace_out = a.outputs();
|
||||||
for (auto& o : trace_out) {
|
for (auto& o : trace_out) {
|
||||||
shapes.push_back(o.shape());
|
|
||||||
types.push_back(o.dtype());
|
types.push_back(o.dtype());
|
||||||
}
|
}
|
||||||
|
std::vector<std::vector<int>> shapes;
|
||||||
|
if (shapeless) {
|
||||||
|
shapes = a.primitive().output_shapes(real_inputs);
|
||||||
|
} else {
|
||||||
|
for (auto& o : trace_out) {
|
||||||
|
shapes.push_back(o.shape());
|
||||||
|
}
|
||||||
|
}
|
||||||
auto real_out =
|
auto real_out =
|
||||||
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
||||||
for (int i = 0; i < trace_out.size(); ++i) {
|
for (int i = 0; i < trace_out.size(); ++i) {
|
||||||
@ -697,13 +736,34 @@ std::vector<array> compile_replace(
|
|||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||||
|
for (auto& t : tape) {
|
||||||
|
if (!t.has_primitive()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto& p = t.primitive();
|
||||||
|
if (allows_shapeless(p)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[compile] Cannot compile primitive ";
|
||||||
|
p.print(msg);
|
||||||
|
msg << " with shapeless enabled.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
size_t fun_id) {
|
size_t fun_id,
|
||||||
|
bool shapeless /* = false */,
|
||||||
|
std::vector<uint64_t> constants /* = {} */) {
|
||||||
if (compile_mode() == CompileMode::disabled) {
|
if (compile_mode() == CompileMode::disabled) {
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
return [fun, fun_id](const std::vector<array>& inputs) {
|
return [fun, fun_id, shapeless, constants = std::move(constants)](
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
// If the inputs are tracers, trace the original graph
|
// If the inputs are tracers, trace the original graph
|
||||||
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
||||||
return in.is_tracer();
|
return in.is_tracer();
|
||||||
@ -712,12 +772,14 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find a cache entry with the correct inputs
|
// Find a cache entry with the correct inputs
|
||||||
auto& entry = compiler_cache().find(fun_id, inputs);
|
auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);
|
||||||
|
|
||||||
// No matching cache entry existed, so compile
|
// No matching cache entry existed, so compile
|
||||||
if (entry.empty) {
|
if (entry.empty) {
|
||||||
// Mark the entry as not empty since we are about to fill it
|
// Mark the entry as not empty since we are about to fill it
|
||||||
entry.empty = false;
|
entry.empty = false;
|
||||||
|
// Set the constants
|
||||||
|
entry.constants = std::move(constants);
|
||||||
// Trace to build the graph
|
// Trace to build the graph
|
||||||
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
||||||
|
|
||||||
@ -739,11 +801,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
if (compile_mode() != CompileMode::no_fuse) {
|
if (compile_mode() != CompileMode::no_fuse) {
|
||||||
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (shapeless) {
|
||||||
|
compile_validate_shapeless(entry.tape);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point we must have a tape, now replace the placeholders
|
// At this point we must have a tape, now replace the placeholders
|
||||||
// with real arrays that can be evaluated
|
// with real arrays that can be evaluated
|
||||||
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
|
return compile_replace(
|
||||||
|
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
bool shapeless /* false */) {
|
||||||
if (detail::compile_mode() == CompileMode::disabled) {
|
if (detail::compile_mode() == CompileMode::disabled) {
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
auto fun_id = detail::getAddress(fun);
|
auto fun_id = detail::getAddress(fun);
|
||||||
return detail::compile(fun, fun_id);
|
return detail::compile(fun, fun_id, shapeless);
|
||||||
}
|
}
|
||||||
|
|
||||||
void disable_compile() {
|
void disable_compile() {
|
||||||
|
@ -8,9 +8,10 @@ namespace mlx::core {
|
|||||||
|
|
||||||
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
|
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
|
||||||
|
|
||||||
// Compile takes a function and returns a new function
|
/** Compile takes a function and returns a compiled function. */
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
/** Globally disable compilation.
|
/** Globally disable compilation.
|
||||||
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||||
|
@ -71,6 +71,15 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
|||||||
throw std::invalid_argument("Primitive's vmap not implemented.");
|
throw std::invalid_argument("Primitive's vmap not implemented.");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> Primitive::output_shapes(
|
||||||
|
const std::vector<array>&) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Primitive::output_shapes] ";
|
||||||
|
this->print(msg);
|
||||||
|
msg << " cannot infer output shapes.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
};
|
||||||
|
|
||||||
std::vector<array> Abs::vjp(
|
std::vector<array> Abs::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
@ -383,6 +392,13 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
|||||||
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
|
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> ArgReduce::output_shapes(
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
auto out_shape = inputs[0].shape();
|
||||||
|
out_shape[axis_] = 1;
|
||||||
|
return {out_shape};
|
||||||
|
}
|
||||||
|
|
||||||
bool ArgSort::is_equivalent(const Primitive& other) const {
|
bool ArgSort::is_equivalent(const Primitive& other) const {
|
||||||
const ArgSort& r_other = static_cast<const ArgSort&>(other);
|
const ArgSort& r_other = static_cast<const ArgSort&>(other);
|
||||||
return axis_ == r_other.axis_;
|
return axis_ == r_other.axis_;
|
||||||
@ -2202,6 +2218,15 @@ bool Reduce::is_equivalent(const Primitive& other) const {
|
|||||||
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> Reduce::output_shapes(
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
std::vector<int> out_shape = inputs[0].shape();
|
||||||
|
for (auto i : axes_) {
|
||||||
|
out_shape[i] = 1;
|
||||||
|
}
|
||||||
|
return {out_shape};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Round::vjp(
|
std::vector<array> Round::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
|
@ -36,6 +36,12 @@
|
|||||||
return true; \
|
return true; \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Defines an output_shapes override that returns a single shape equal to the
// shape of the first input.
#define DEFINE_INPUT_OUTPUT_SHAPE()                      \
  std::vector<std::vector<int>> output_shapes(           \
      const std::vector<array>& inputs) override {       \
    return {inputs.front().shape()};                     \
  };
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
// Abstract base class
|
// Abstract base class
|
||||||
@ -102,6 +108,11 @@ class Primitive {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Get the output shapes of the primitive. This is not required to be
|
||||||
|
* implemented by derived classes, in which case it will throw. */
|
||||||
|
virtual std::vector<std::vector<int>> output_shapes(
|
||||||
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
virtual ~Primitive() = default;
|
virtual ~Primitive() = default;
|
||||||
Primitive(const Primitive& other) = delete;
|
Primitive(const Primitive& other) = delete;
|
||||||
Primitive(Primitive&& other) = delete;
|
Primitive(Primitive&& other) = delete;
|
||||||
@ -152,6 +163,7 @@ class Abs : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Abs)
|
DEFINE_PRINT(Abs)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -168,6 +180,7 @@ class Add : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Add)
|
DEFINE_PRINT(Add)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -226,6 +239,7 @@ class ArcCos : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ArcCos)
|
DEFINE_PRINT(ArcCos)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -242,6 +256,7 @@ class ArcCosh : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ArcCosh)
|
DEFINE_PRINT(ArcCosh)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -258,6 +273,7 @@ class ArcSin : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ArcSin)
|
DEFINE_PRINT(ArcSin)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -274,6 +290,7 @@ class ArcSinh : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ArcSinh)
|
DEFINE_PRINT(ArcSinh)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -290,6 +307,7 @@ class ArcTan : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ArcTan)
|
DEFINE_PRINT(ArcTan)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -306,6 +324,7 @@ class ArcTanh : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ArcTanh)
|
DEFINE_PRINT(ArcTanh)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -321,6 +340,7 @@ class ArgPartition : public UnaryPrimitive {
|
|||||||
|
|
||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(ArgPartition)
|
DEFINE_PRINT(ArgPartition)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -346,6 +366,8 @@ class ArgReduce : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(ArgReduce)
|
DEFINE_PRINT(ArgReduce)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<std::vector<int>> output_shapes(
|
||||||
|
const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ReduceType reduce_type_;
|
ReduceType reduce_type_;
|
||||||
@ -364,6 +386,7 @@ class ArgSort : public UnaryPrimitive {
|
|||||||
|
|
||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(ArgSort)
|
DEFINE_PRINT(ArgSort)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -383,6 +406,7 @@ class AsType : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(AsType)
|
DEFINE_PRINT(AsType)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -448,6 +472,7 @@ class Ceil : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Ceil)
|
DEFINE_PRINT(Ceil)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -478,15 +503,14 @@ class Compiled : public Primitive {
|
|||||||
|
|
||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
|
std::vector<std::vector<int>> output_shapes(
|
||||||
|
const std::vector<array>& inputs) override;
|
||||||
void print(std::ostream& os) override;
|
void print(std::ostream& os) override;
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
std::string metal_lib_name() const {
|
std::string lib_name() const {
|
||||||
return kernel_lib_;
|
return kernel_lib_;
|
||||||
}
|
}
|
||||||
std::string metal_lib_source() const {
|
|
||||||
return kernel_source_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const std::vector<array> inputs_;
|
const std::vector<array> inputs_;
|
||||||
@ -495,7 +519,6 @@ class Compiled : public Primitive {
|
|||||||
const std::unordered_set<uintptr_t> constant_ids_;
|
const std::unordered_set<uintptr_t> constant_ids_;
|
||||||
|
|
||||||
std::string kernel_lib_;
|
std::string kernel_lib_;
|
||||||
std::string kernel_source_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class Concatenate : public UnaryPrimitive {
|
class Concatenate : public UnaryPrimitive {
|
||||||
@ -563,6 +586,7 @@ class Copy : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Copy)
|
DEFINE_PRINT(Copy)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -579,6 +603,7 @@ class Cos : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Cos)
|
DEFINE_PRINT(Cos)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -595,6 +620,7 @@ class Cosh : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Cosh)
|
DEFINE_PRINT(Cosh)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -665,6 +691,7 @@ class Divide : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Divide)
|
DEFINE_PRINT(Divide)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -683,6 +710,10 @@ class DivMod : public Primitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(DivMod)
|
DEFINE_PRINT(DivMod)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
std::vector<std::vector<int>> output_shapes(
|
||||||
|
const std::vector<array>& inputs) override {
|
||||||
|
return std::vector{inputs[0].shape(), inputs[0].shape()};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||||
@ -699,6 +730,7 @@ class Remainder : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Remainder)
|
DEFINE_PRINT(Remainder)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -715,6 +747,7 @@ class Equal : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
void print(std::ostream& os) override {
|
||||||
if (equal_nan_) {
|
if (equal_nan_) {
|
||||||
@ -740,6 +773,7 @@ class Erf : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Erf)
|
DEFINE_PRINT(Erf)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -756,6 +790,7 @@ class ErfInv : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ErfInv)
|
DEFINE_PRINT(ErfInv)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -772,6 +807,7 @@ class Exp : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Exp)
|
DEFINE_PRINT(Exp)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -814,6 +850,7 @@ class Floor : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Floor)
|
DEFINE_PRINT(Floor)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -868,6 +905,7 @@ class Greater : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Greater)
|
DEFINE_PRINT(Greater)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -884,6 +922,7 @@ class GreaterEqual : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(GreaterEqual)
|
DEFINE_PRINT(GreaterEqual)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -900,6 +939,7 @@ class Less : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Less)
|
DEFINE_PRINT(Less)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -916,6 +956,7 @@ class LessEqual : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(LessEqual)
|
DEFINE_PRINT(LessEqual)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -958,6 +999,7 @@ class Log : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
void print(std::ostream& os) override {
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
@ -988,6 +1030,7 @@ class Log1p : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Log1p)
|
DEFINE_PRINT(Log1p)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1004,6 +1047,7 @@ class LogicalNot : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(LogicalNot)
|
DEFINE_PRINT(LogicalNot)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1020,6 +1064,7 @@ class LogicalAnd : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(LogicalAnd)
|
DEFINE_PRINT(LogicalAnd)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1036,6 +1081,7 @@ class LogicalOr : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(LogicalOr)
|
DEFINE_PRINT(LogicalOr)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1052,6 +1098,7 @@ class LogAddExp : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(LogAddExp)
|
DEFINE_PRINT(LogAddExp)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1085,6 +1132,7 @@ class Maximum : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Maximum)
|
DEFINE_PRINT(Maximum)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1101,6 +1149,7 @@ class Minimum : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Minimum)
|
DEFINE_PRINT(Minimum)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1117,6 +1166,7 @@ class Multiply : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Multiply)
|
DEFINE_PRINT(Multiply)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1133,6 +1183,7 @@ class Negative : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Negative)
|
DEFINE_PRINT(Negative)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1149,6 +1200,7 @@ class NotEqual : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(NotEqual)
|
DEFINE_PRINT(NotEqual)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1193,6 +1245,7 @@ class Partition : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Partition)
|
DEFINE_PRINT(Partition)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -1213,6 +1266,7 @@ class Power : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Power)
|
DEFINE_PRINT(Power)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1305,6 +1359,9 @@ class Reduce : public UnaryPrimitive {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> output_shapes(
|
||||||
|
const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
void print(std::ostream& os) override {
|
||||||
switch (reduce_type_) {
|
switch (reduce_type_) {
|
||||||
case And:
|
case And:
|
||||||
@ -1347,6 +1404,7 @@ class Round : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Round)
|
DEFINE_PRINT(Round)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1455,6 +1513,7 @@ class Sigmoid : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Sigmoid)
|
DEFINE_PRINT(Sigmoid)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1471,6 +1530,7 @@ class Sign : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Sign)
|
DEFINE_PRINT(Sign)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1487,6 +1547,7 @@ class Sin : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Sin)
|
DEFINE_PRINT(Sin)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1503,6 +1564,7 @@ class Sinh : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Sinh)
|
DEFINE_PRINT(Sinh)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1547,6 +1609,7 @@ class Softmax : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Softmax)
|
DEFINE_PRINT(Softmax)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1563,6 +1626,7 @@ class Sort : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Sort)
|
DEFINE_PRINT(Sort)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -1604,6 +1668,7 @@ class Square : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Square)
|
DEFINE_PRINT(Square)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1619,6 +1684,7 @@ class Sqrt : public UnaryPrimitive {
|
|||||||
|
|
||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
void print(std::ostream& os) override {
|
||||||
@ -1644,6 +1710,7 @@ class StopGradient : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(StopGradient)
|
DEFINE_PRINT(StopGradient)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1660,6 +1727,7 @@ class Subtract : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Subtract)
|
DEFINE_PRINT(Subtract)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1676,6 +1744,7 @@ class Tan : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Tan)
|
DEFINE_PRINT(Tan)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1692,6 +1761,7 @@ class Tanh : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Tanh)
|
DEFINE_PRINT(Tanh)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
@ -18,7 +18,9 @@ std::vector<array> vmap_replace(
|
|||||||
// idea.
|
// idea.
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
size_t fun_id);
|
size_t fun_id,
|
||||||
|
bool shapeless = false,
|
||||||
|
std::vector<uint64_t> constants = {});
|
||||||
|
|
||||||
// Erase cached compile functions
|
// Erase cached compile functions
|
||||||
void compile_erase(size_t fun_id);
|
void compile_erase(size_t fun_id);
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module
|
|||||||
|
|
||||||
def _make_activation_module(f):
|
def _make_activation_module(f):
|
||||||
def decorator(klass):
|
def decorator(klass):
|
||||||
klass.__doc__ = f.__doc__
|
klass.__call__ = lambda _, x: f(x)
|
||||||
klass.__call__ = lambda self, x: f(x)
|
|
||||||
return klass
|
return klass
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def sigmoid(x):
|
def sigmoid(x):
|
||||||
r"""Applies the element-wise function:
|
r"""Applies the element-wise function:
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ def sigmoid(x):
|
|||||||
return mx.sigmoid(x)
|
return mx.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def relu(x):
|
def relu(x):
|
||||||
r"""Applies the Rectified Linear Unit.
|
r"""Applies the Rectified Linear Unit.
|
||||||
|
|
||||||
@ -33,6 +35,7 @@ def relu(x):
|
|||||||
return mx.maximum(x, 0)
|
return mx.maximum(x, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def leaky_relu(x, negative_slope=0.01):
|
def leaky_relu(x, negative_slope=0.01):
|
||||||
r"""Applies the Leaky Rectified Linear Unit.
|
r"""Applies the Leaky Rectified Linear Unit.
|
||||||
|
|
||||||
@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01):
|
|||||||
return mx.maximum(negative_slope * x, x)
|
return mx.maximum(negative_slope * x, x)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def log_softmax(x, axis=-1):
|
def log_softmax(x, axis=-1):
|
||||||
r"""Applies the Log Softmax function.
|
r"""Applies the Log Softmax function.
|
||||||
|
|
||||||
@ -49,6 +53,7 @@ def log_softmax(x, axis=-1):
|
|||||||
return x - mx.logsumexp(x, axis=axis, keepdims=True)
|
return x - mx.logsumexp(x, axis=axis, keepdims=True)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def elu(x, alpha=1.0):
|
def elu(x, alpha=1.0):
|
||||||
r"""Applies the Exponential Linear Unit.
|
r"""Applies the Exponential Linear Unit.
|
||||||
|
|
||||||
@ -57,6 +62,7 @@ def elu(x, alpha=1.0):
|
|||||||
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
|
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def relu6(x):
|
def relu6(x):
|
||||||
r"""Applies the Rectified Linear Unit 6.
|
r"""Applies the Rectified Linear Unit 6.
|
||||||
|
|
||||||
@ -65,6 +71,7 @@ def relu6(x):
|
|||||||
return mx.minimum(mx.maximum(x, 0), 6.0)
|
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def softmax(x, axis=-1):
|
def softmax(x, axis=-1):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
|
|
||||||
@ -73,6 +80,7 @@ def softmax(x, axis=-1):
|
|||||||
return mx.softmax(x, axis=axis)
|
return mx.softmax(x, axis=axis)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def softplus(x):
|
def softplus(x):
|
||||||
r"""Applies the Softplus function.
|
r"""Applies the Softplus function.
|
||||||
|
|
||||||
@ -81,6 +89,7 @@ def softplus(x):
|
|||||||
return mx.logaddexp(x, 0)
|
return mx.logaddexp(x, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def softsign(x):
|
def softsign(x):
|
||||||
r"""Applies the Softsign function.
|
r"""Applies the Softsign function.
|
||||||
|
|
||||||
@ -89,6 +98,7 @@ def softsign(x):
|
|||||||
return mx.divide(x, 1 + mx.abs(x))
|
return mx.divide(x, 1 + mx.abs(x))
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def softshrink(x, lambd: float = 0.5):
|
def softshrink(x, lambd: float = 0.5):
|
||||||
r"""Applies the Softshrink activation function.
|
r"""Applies the Softshrink activation function.
|
||||||
|
|
||||||
@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5):
|
|||||||
return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
|
return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def celu(x, alpha=1.0):
|
def celu(x, alpha=1.0):
|
||||||
r"""Applies the Continuously Differentiable Exponential Linear Unit.
|
r"""Applies the Continuously Differentiable Exponential Linear Unit.
|
||||||
|
|
||||||
@ -111,6 +122,7 @@ def celu(x, alpha=1.0):
|
|||||||
return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
|
return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def silu(x):
|
def silu(x):
|
||||||
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
|
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
|
||||||
|
|
||||||
@ -120,6 +132,7 @@ def silu(x):
|
|||||||
return x * mx.sigmoid(x)
|
return x * mx.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def log_sigmoid(x):
|
def log_sigmoid(x):
|
||||||
r"""Applies the Log Sigmoid function.
|
r"""Applies the Log Sigmoid function.
|
||||||
|
|
||||||
@ -128,6 +141,7 @@ def log_sigmoid(x):
|
|||||||
return -softplus(-x)
|
return -softplus(-x)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
r"""Applies the Gaussian Error Linear Units function.
|
r"""Applies the Gaussian Error Linear Units function.
|
||||||
|
|
||||||
@ -142,6 +156,7 @@ def gelu(x):
|
|||||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def gelu_approx(x):
|
def gelu_approx(x):
|
||||||
r"""An approximation to Gaussian Error Linear Unit.
|
r"""An approximation to Gaussian Error Linear Unit.
|
||||||
|
|
||||||
@ -159,6 +174,7 @@ def gelu_approx(x):
|
|||||||
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
|
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def gelu_fast_approx(x):
|
def gelu_fast_approx(x):
|
||||||
r"""A fast approximation to Gaussian Error Linear Unit.
|
r"""A fast approximation to Gaussian Error Linear Unit.
|
||||||
|
|
||||||
@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
|
|||||||
return a * mx.sigmoid(b)
|
return a * mx.sigmoid(b)
|
||||||
|
|
||||||
|
|
||||||
class GLU(Module):
|
@partial(mx.compile, shapeless=True)
|
||||||
r"""Applies the gated linear unit function.
|
|
||||||
|
|
||||||
This function splits the ``axis`` dimension of the input into two halves
|
|
||||||
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
|
||||||
|
|
||||||
.. math::
|
|
||||||
\textrm{GLU}(x) = a * \sigma(b)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
axis (int): The dimension to split along. Default: ``-1``
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, axis: int = -1):
|
|
||||||
super().__init__()
|
|
||||||
self.axis = axis
|
|
||||||
|
|
||||||
def __call__(self, x) -> Any:
|
|
||||||
return glu(x=x, axis=self.axis)
|
|
||||||
|
|
||||||
|
|
||||||
def step(x: mx.array, threshold: float = 0.0):
|
def step(x: mx.array, threshold: float = 0.0):
|
||||||
r"""Applies the Step Activation Function.
|
r"""Applies the Step Activation Function.
|
||||||
|
|
||||||
@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0):
|
|||||||
return mx.where(x > threshold, 1, 0)
|
return mx.where(x > threshold, 1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def selu(x):
|
def selu(x):
|
||||||
r"""Applies the Scaled Exponential Linear Unit.
|
r"""Applies the Scaled Exponential Linear Unit.
|
||||||
|
|
||||||
@ -248,6 +245,7 @@ def selu(x):
|
|||||||
return elu(x, 1.67326) * 1.0507
|
return elu(x, 1.67326) * 1.0507
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
||||||
r"""Applies the element-wise parametric ReLU.
|
r"""Applies the element-wise parametric ReLU.
|
||||||
|
|
||||||
@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
|||||||
return mx.maximum(0, x) + alpha * mx.minimum(0, x)
|
return mx.maximum(0, x) + alpha * mx.minimum(0, x)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def mish(x: mx.array) -> mx.array:
|
def mish(x: mx.array) -> mx.array:
|
||||||
r"""Applies the Mish function, element-wise.
|
r"""Applies the Mish function, element-wise.
|
||||||
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
|
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
|
||||||
@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array:
|
|||||||
return x * mx.tanh(softplus(x))
|
return x * mx.tanh(softplus(x))
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
def hardswish(x):
|
def hardswish(x):
|
||||||
r"""Applies the hardswish function, element-wise.
|
r"""Applies the hardswish function, element-wise.
|
||||||
|
|
||||||
@ -282,6 +282,35 @@ def hardswish(x):
|
|||||||
return x * mx.minimum(max_x_3, 6) / 6
|
return x * mx.minimum(max_x_3, 6) / 6
|
||||||
|
|
||||||
|
|
||||||
|
def tanh(x):
|
||||||
|
"""Applies the hyperbolic tangent function.
|
||||||
|
|
||||||
|
Simply ``mx.tanh(x)``.
|
||||||
|
"""
|
||||||
|
return mx.tanh(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GLU(Module):
|
||||||
|
r"""Applies the gated linear unit function.
|
||||||
|
|
||||||
|
This function splits the ``axis`` dimension of the input into two halves
|
||||||
|
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\textrm{GLU}(x) = a * \sigma(b)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
axis (int): The dimension to split along. Default: ``-1``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, axis: int = -1):
|
||||||
|
super().__init__()
|
||||||
|
self.axis = axis
|
||||||
|
|
||||||
|
def __call__(self, x) -> Any:
|
||||||
|
return glu(x=x, axis=self.axis)
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(mx.sigmoid)
|
@_make_activation_module(mx.sigmoid)
|
||||||
class Sigmoid(Module):
|
class Sigmoid(Module):
|
||||||
r"""Applies the sigmoid function, element-wise.
|
r"""Applies the sigmoid function, element-wise.
|
||||||
@ -500,14 +529,6 @@ class GELU(Module):
|
|||||||
return self._act(x)
|
return self._act(x)
|
||||||
|
|
||||||
|
|
||||||
def tanh(x):
|
|
||||||
"""Applies the hyperbolic tangent function.
|
|
||||||
|
|
||||||
Simply ``mx.tanh(x)``.
|
|
||||||
"""
|
|
||||||
return mx.tanh(x)
|
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(tanh)
|
@_make_activation_module(tanh)
|
||||||
class Tanh(Module):
|
class Tanh(Module):
|
||||||
r"""Applies the hyperbolic tangent function.
|
r"""Applies the hyperbolic tangent function.
|
||||||
|
@ -555,13 +555,19 @@ struct PyCompiledFun {
|
|||||||
size_t fun_id;
|
size_t fun_id;
|
||||||
py::object captured_inputs;
|
py::object captured_inputs;
|
||||||
py::object captured_outputs;
|
py::object captured_outputs;
|
||||||
|
bool shapeless;
|
||||||
size_t num_outputs{0};
|
size_t num_outputs{0};
|
||||||
|
|
||||||
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
|
PyCompiledFun(
|
||||||
|
const py::function& fun,
|
||||||
|
py::object inputs,
|
||||||
|
py::object outputs,
|
||||||
|
bool shapeless)
|
||||||
: fun(fun),
|
: fun(fun),
|
||||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||||
captured_inputs(inputs),
|
captured_inputs(inputs),
|
||||||
captured_outputs(outputs) {}
|
captured_outputs(outputs),
|
||||||
|
shapeless(shapeless) {}
|
||||||
|
|
||||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||||
@ -571,11 +577,15 @@ struct PyCompiledFun {
|
|||||||
other.fun_id = 0;
|
other.fun_id = 0;
|
||||||
captured_inputs = std::move(other.captured_inputs);
|
captured_inputs = std::move(other.captured_inputs);
|
||||||
captured_outputs = std::move(other.captured_outputs);
|
captured_outputs = std::move(other.captured_outputs);
|
||||||
|
shapeless = other.shapeless;
|
||||||
num_outputs = other.num_outputs;
|
num_outputs = other.num_outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
py::object operator()(const py::args& args) {
|
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
auto inputs = tree_flatten(args, false);
|
||||||
|
|
||||||
|
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
|
||||||
|
const std::vector<array>& a) {
|
||||||
// Put tracers into captured inputs
|
// Put tracers into captured inputs
|
||||||
std::vector<array> flat_in_captures;
|
std::vector<array> flat_in_captures;
|
||||||
std::vector<array> trace_captures;
|
std::vector<array> trace_captures;
|
||||||
@ -586,8 +596,10 @@ struct PyCompiledFun {
|
|||||||
tree_fill(captured_inputs, trace_captures);
|
tree_fill(captured_inputs, trace_captures);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [outputs, py_outputs] = tree_flatten_with_structure(
|
auto tree_outputs =
|
||||||
std::move(fun(*tree_unflatten(args, a))), false);
|
fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));
|
||||||
|
auto [outputs, py_outputs] =
|
||||||
|
tree_flatten_with_structure(std::move(tree_outputs), false);
|
||||||
|
|
||||||
tree_cache().insert({fun_id, py_outputs});
|
tree_cache().insert({fun_id, py_outputs});
|
||||||
|
|
||||||
@ -607,7 +619,14 @@ struct PyCompiledFun {
|
|||||||
return outputs;
|
return outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto inputs = tree_flatten(args, false);
|
{
|
||||||
|
auto flat_kwargs = tree_flatten(kwargs, false);
|
||||||
|
inputs.insert(
|
||||||
|
inputs.end(),
|
||||||
|
std::make_move_iterator(flat_kwargs.begin()),
|
||||||
|
std::make_move_iterator(flat_kwargs.end()));
|
||||||
|
}
|
||||||
|
|
||||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||||
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
||||||
inputs.insert(
|
inputs.insert(
|
||||||
@ -616,8 +635,39 @@ struct PyCompiledFun {
|
|||||||
std::make_move_iterator(flat_in_captures.end()));
|
std::make_move_iterator(flat_in_captures.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Collect the compilation constants
|
||||||
|
std::vector<uint64_t> constants;
|
||||||
|
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
|
||||||
|
// Consider expanding tuples to their contents including start and end
|
||||||
|
// ids
|
||||||
|
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
|
||||||
|
auto r = py::hash(o);
|
||||||
|
return *reinterpret_cast<uint64_t*>(&r);
|
||||||
|
} else if (py::isinstance<py::int_>(o)) {
|
||||||
|
auto r = o.cast<int64_t>();
|
||||||
|
return *reinterpret_cast<uint64_t*>(&r);
|
||||||
|
} else if (py::isinstance<py::float_>(o)) {
|
||||||
|
auto r = o.cast<double>();
|
||||||
|
return *reinterpret_cast<uint64_t*>(&r);
|
||||||
|
} else {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (int i = 0; i < args.size(); i++) {
|
||||||
|
if (auto h = value_hash(args[i]); h.has_value()) {
|
||||||
|
constants.push_back(*h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& pair : kwargs) {
|
||||||
|
if (auto h = value_hash(pair.second); h.has_value()) {
|
||||||
|
constants.push_back(*value_hash(pair.first));
|
||||||
|
constants.push_back(*h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Compile and call
|
// Compile and call
|
||||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
auto outputs =
|
||||||
|
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||||
std::vector<array> captures(
|
std::vector<array> captures(
|
||||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||||
@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) {
|
|||||||
"compile",
|
"compile",
|
||||||
[](const py::function& fun,
|
[](const py::function& fun,
|
||||||
const py::object& inputs,
|
const py::object& inputs,
|
||||||
const py::object& outputs) {
|
const py::object& outputs,
|
||||||
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
|
bool shapeless) {
|
||||||
|
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"inputs"_a = std::nullopt,
|
"inputs"_a = std::nullopt,
|
||||||
"outputs"_a = std::nullopt,
|
"outputs"_a = std::nullopt,
|
||||||
|
"shapeless"_a = false,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
compile(fun: function) -> function
|
compile(fun: function) -> function
|
||||||
|
|
||||||
@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) {
|
|||||||
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
|
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
|
||||||
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
|
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
|
||||||
Default: ``None``
|
Default: ``None``
|
||||||
|
shapeless (bool, optional): A function compiled with the ``shapeless``
|
||||||
|
option enabled will not be recompiled when the input shape changes. Not all
|
||||||
|
functions can be compiled with ``shapeless`` enabled. Attempting to compile
|
||||||
|
such functions with shapeless enabled will throw. Note, changing the number
|
||||||
|
of dimensions or type of any input will result in a recompilation even with
|
||||||
|
``shapeless`` set to ``True``. Default: ``False``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
function: A compiled function which has the same input arguments
|
function: A compiled function which has the same input arguments
|
||||||
|
@ -381,6 +381,164 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
|
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
|
||||||
|
|
||||||
|
def test_compile_kwargs(self):
    """A compiled function accepts array inputs passed as keyword arguments."""

    @mx.compile
    def fun(x, y, z):
        return x + y + z

    first = mx.array(1)
    second = mx.array(2)
    third = mx.array(3)

    # One positional array and two keyword arrays: 1 + 2 + 3.
    total = fun(first, y=second, z=third)
    self.assertEqual(total.item(), 6)
|
||||||
|
|
||||||
|
def test_shapeless_compile(self):
    """Check when a ``shapeless=True`` compiled function is and is not retraced.

    A change of input shape alone must reuse the existing trace, while a
    change of dtype or of the number of dimensions must retrace. The captured
    Python value ``y`` is used as the probe: its current value only shows up
    in the output after a retrace.
    """
    y = 1

    @partial(mx.compile, shapeless=True)
    def fun(x):
        return x + y

    x = mx.array([1, 2])
    self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))

    # The function is not recompiled, so the change
    # to y should not be reflected in the output
    y = 2
    x = mx.array([1, 2, 3])
    self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))

    # Type change recompiles
    x = mx.array([1.0, 2.0, 3.0])
    self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))

    # Dim change recompiles
    x = mx.array([[1, 2, 3]])
    self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))
|
||||||
|
|
||||||
|
def test_shapeless_compile_with_broadcasts(self):
    """A shapeless-compiled product must match the eager result when operands broadcast."""
    x = mx.ones((2, 2))

    def fun(x, y):
        return x * y

    cfun = None
    # First a 1-D operand, then a (1, 1) operand; each is checked in both
    # argument orders against the same compiled function.
    for y_values in ([2, 2], [[3]]):
        y = mx.array(y_values)
        if cfun is None:
            cfun = mx.compile(fun, shapeless=True)
        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
|
||||||
|
|
||||||
|
def test_shapeless_compile_with_reduction(self):
    """Reductions inside a shapeless-compiled function must work on a new input shape."""
    # Test shapeless compile with a reduction
    z = 1

    @partial(mx.compile, shapeless=True)
    def fun(x, y):
        return x + y.sum(0, keepdims=True) + z

    x = mx.ones((2, 2), mx.int32)
    y = mx.ones((2, 2), mx.int32)
    got = fun(x, y)
    self.assertTrue(mx.array_equal(got, mx.full(shape=(2, 2), vals=4)))

    x = mx.ones((3, 3), mx.int32)
    y = mx.ones((3, 3), mx.int32)
    # Rebinding z is expected to go unnoticed: only the shape changed, so the
    # asserted result still corresponds to z == 1 (1 + 3 + 1).
    z = 2
    got = fun(x, y)
    self.assertTrue(mx.array_equal(got, mx.full(shape=(3, 3), vals=5)))

    tall = mx.array([[1, 2], [3, 4], [5, 6]])
    single_row = mx.array([[1, 2]])

    def fun(x):
        return x * x.sum(-1, keepdims=True)

    cfun = mx.compile(fun, shapeless=True)
    # Trace on the (3, 2) input first, then compare against eager on (1, 2).
    mx.eval(cfun(tall))
    self.assertTrue(mx.array_equal(fun(single_row), cfun(single_row)))
|
||||||
|
|
||||||
|
def test_compile_with_constant(self):
    """Changing a non-array argument must change the compiled function's result.

    Covers float, tuple, bool and string arguments, passed positionally and
    (for floats) by keyword.
    """

    # Test float
    @mx.compile
    def add(x, y):
        return x + y

    for y, expected in ((1.0, 2.0), (2.0, 3.0)):
        self.assertEqual(add(mx.array(1.0), y).item(), expected)

    for y, expected in ((1.0, 2.0), (3.0, 4.0)):
        self.assertEqual(add(mx.array(1.0), y=y).item(), expected)

    # Test tuple
    @mx.compile
    def add_pair(x, y=(1, 2)):
        return x + y[0] + y[1]

    self.assertEqual(add_pair(mx.array(1)).item(), 4)
    for pair, expected in (((2, 2), 5), ((2, 1), 4)):
        self.assertEqual(add_pair(mx.array(1), pair).item(), expected)

    # Test bool
    @mx.compile
    def branch_on_flag(x, y):
        return x + 1 if y else x + 2

    for flag, expected in ((True, 2), (False, 3)):
        self.assertEqual(branch_on_flag(mx.array(1), flag).item(), expected)

    # Test string
    @mx.compile
    def branch_on_name(x, y):
        return x + 1 if y == "one" else x + 2

    for name, expected in (("one", 2), ("two", 3)):
        self.assertEqual(branch_on_name(mx.array(1), name).item(), expected)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -624,31 +624,23 @@ TEST_CASE("test transform compiled function") {
|
|||||||
CHECK(!outs[0].inputs()[1].has_primitive());
|
CHECK(!outs[0].inputs()[1].has_primitive());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal fusion kernel reuse") {
|
TEST_CASE("test fusion kernel reuse") {
|
||||||
if (default_device() != Device::gpu) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cfun = compile(gelu_1);
|
auto cfun = compile(gelu_1);
|
||||||
auto x = array({2.0f, -2.0f});
|
auto x = array({2.0f, -2.0f});
|
||||||
auto y = cfun({x})[0];
|
auto y = cfun({x})[0];
|
||||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||||
eval(y);
|
eval(y);
|
||||||
|
|
||||||
std::string lib_name = p->metal_lib_name();
|
std::string lib_name = p->lib_name();
|
||||||
std::string lib_source = p->metal_lib_source();
|
|
||||||
CHECK(!lib_name.empty());
|
CHECK(!lib_name.empty());
|
||||||
CHECK(!lib_source.empty());
|
|
||||||
|
|
||||||
x = astype(reshape(arange(10), {2, 5}), float32);
|
x = astype(reshape(arange(10), {2, 5}), float32);
|
||||||
auto z = cfun({x})[0];
|
auto z = cfun({x})[0];
|
||||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||||
eval(z);
|
eval(z);
|
||||||
|
|
||||||
std::string lib_name_z = pz->metal_lib_name();
|
std::string lib_name_z = pz->lib_name();
|
||||||
std::string lib_source_z = pz->metal_lib_source();
|
|
||||||
CHECK(!lib_name_z.empty());
|
CHECK(!lib_name_z.empty());
|
||||||
CHECK(lib_source_z.empty());
|
|
||||||
|
|
||||||
CHECK_EQ(lib_name, lib_name_z);
|
CHECK_EQ(lib_name, lib_name_z);
|
||||||
}
|
}
|
||||||
@ -657,29 +649,57 @@ auto add3(const std::vector<array>& xs) {
|
|||||||
return std::vector<array>{xs[0] + xs[0] + xs[0]};
|
return std::vector<array>{xs[0] + xs[0] + xs[0]};
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal fusion types") {
|
TEST_CASE("test fusion types") {
|
||||||
if (default_device() != Device::gpu) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cfun = compile(add3);
|
auto cfun = compile(add3);
|
||||||
auto x = array({2.0f, -2.0f});
|
auto x = array({2.0f, -2.0f});
|
||||||
auto y = cfun({x})[0];
|
auto y = cfun({x})[0];
|
||||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||||
eval(y);
|
eval(y);
|
||||||
|
|
||||||
std::string lib_name = p->metal_lib_name();
|
std::string lib_name = p->lib_name();
|
||||||
std::string lib_source = p->metal_lib_source();
|
|
||||||
CHECK(!lib_name.empty());
|
CHECK(!lib_name.empty());
|
||||||
CHECK(!lib_source.empty());
|
|
||||||
|
|
||||||
x = array({2, -2}, int32);
|
x = array({2, -2}, int32);
|
||||||
auto z = cfun({x})[0];
|
auto z = cfun({x})[0];
|
||||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||||
eval(z);
|
eval(z);
|
||||||
|
|
||||||
std::string lib_name_z = pz->metal_lib_name();
|
std::string lib_name_z = pz->lib_name();
|
||||||
std::string lib_source_z = pz->metal_lib_source();
|
|
||||||
CHECK(!lib_name_z.empty());
|
CHECK(!lib_name_z.empty());
|
||||||
CHECK(!lib_source_z.empty());
|
}
|
||||||
|
|
||||||
|
auto compile_shapeless_not_ok(const std::vector<array>& inputs) {
|
||||||
|
auto x = reshape(inputs[0], {2, 2});
|
||||||
|
return std::vector<array>{x};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto compile_shapeless_ok(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[0] + array({2});
|
||||||
|
return std::vector<array>{x};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test shapeless compile") {
|
||||||
|
{
|
||||||
|
auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true);
|
||||||
|
CHECK_THROWS(cfun({array({1, 2, 3, 4})}));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto cfun = compile(compile_shapeless_ok, /* shapeless */ true);
|
||||||
|
auto out = cfun({array({1, 2})})[0];
|
||||||
|
auto out2 = cfun({array({1, 2, 3, 4})})[0];
|
||||||
|
|
||||||
|
// Not making a new constant array since no recompile,
|
||||||
|
// hence the ids should be the same
|
||||||
|
CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||||
|
CHECK(array_equal(out2, array({3, 4, 5, 6})).item<bool>());
|
||||||
|
|
||||||
|
// Recompile since type changes
|
||||||
|
out2 = cfun({array({1.0, 2.0})})[0];
|
||||||
|
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||||
|
|
||||||
|
// Recompile since ndim changes
|
||||||
|
out2 = cfun({array({1.0, 2.0}, {1, 2})})[0];
|
||||||
|
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user