mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-25 12:48:14 +08:00
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -37,7 +37,7 @@ std::string build_lib_name(
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << ((x.size() == 1) ? "S" : "V");
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
@@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_scalar(const array& x) {
|
||||
return x.size() == 1;
|
||||
};
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
const std::string& kernel_name,
|
||||
@@ -358,7 +354,7 @@ void Compiled::eval_cpu(
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (auto& x : inputs) {
|
||||
if (x.size() == 1) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
@@ -385,7 +381,7 @@ void Compiled::eval_cpu(
|
||||
auto& x = inputs[i];
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || x.size() <= 1) {
|
||||
if (contiguous || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -458,7 +454,7 @@ void Compiled::eval_cpu(
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().contiguous && in.size() > 1 && in.is_donatable() &&
|
||||
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
@@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) {
|
||||
|
||||
void print_constant(std::ostream& os, const array& x);
|
||||
|
||||
inline bool is_scalar(const array& x) {
|
||||
return x.ndim() == 0;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -31,9 +31,6 @@ inline void build_kernel(
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
// For scalar we shouldn't do the indexing things, just read at 0
|
||||
auto is_scalar = [](const array& x) { return x.size() == 1; };
|
||||
|
||||
NodeNamer namer;
|
||||
bool add_indices = false;
|
||||
int cnt = 0;
|
||||
@@ -226,8 +223,7 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true);
|
||||
|
||||
kernel_source_ = kernel.str();
|
||||
lib = d.get_library(kernel_lib_, kernel_source_);
|
||||
lib = d.get_library(kernel_lib_, kernel.str());
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
@@ -235,7 +231,7 @@ void Compiled::eval_gpu(
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
x.size() > 1) {
|
||||
!is_scalar(x)) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
@@ -256,7 +252,7 @@ void Compiled::eval_gpu(
|
||||
auto& x = inputs[i];
|
||||
|
||||
// Skip scalar inputs.
|
||||
if (x.size() <= 1) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -311,7 +307,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
if (!contiguous && x.size() > 1) {
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
|
||||
Reference in New Issue
Block a user