Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-19 21:43:54 -08:00
committed by GitHub
parent d0fda82595
commit 5798256fcf
14 changed files with 645 additions and 113 deletions

View File

@@ -37,7 +37,7 @@ std::string build_lib_name(
os << "C";
print_constant(constant_hasher, x);
} else {
os << ((x.size() == 1) ? "S" : "V");
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
@@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) {
}
}
inline bool is_scalar(const array& x) {
return x.size() == 1;
};
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
@@ -358,7 +354,7 @@ void Compiled::eval_cpu(
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (x.size() == 1) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
@@ -385,7 +381,7 @@ void Compiled::eval_cpu(
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || x.size() <= 1) {
if (contiguous || is_scalar(x)) {
continue;
}
@@ -458,7 +454,7 @@ void Compiled::eval_cpu(
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && in.size() > 1 && in.is_donatable() &&
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}

View File

@@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) {
void print_constant(std::ostream& os, const array& x);
inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
} // namespace mlx::core

View File

@@ -31,9 +31,6 @@ inline void build_kernel(
return constant_ids.find(x.id()) != constant_ids.end();
};
// For scalar we shouldn't do the indexing things, just read at 0
auto is_scalar = [](const array& x) { return x.size() == 1; };
NodeNamer namer;
bool add_indices = false;
int cnt = 0;
@@ -226,8 +223,7 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ true);
kernel_source_ = kernel.str();
lib = d.get_library(kernel_lib_, kernel_source_);
lib = d.get_library(kernel_lib_, kernel.str());
}
// Figure out which kernel we are using
@@ -235,7 +231,7 @@ void Compiled::eval_gpu(
bool contiguous = true;
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
x.size() > 1) {
!is_scalar(x)) {
contiguous = false;
break;
}
@@ -256,7 +252,7 @@ void Compiled::eval_gpu(
auto& x = inputs[i];
// Skip scalar inputs.
if (x.size() <= 1) {
if (is_scalar(x)) {
continue;
}
@@ -311,7 +307,7 @@ void Compiled::eval_gpu(
}
auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++);
if (!contiguous && x.size() > 1) {
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),