mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 19:28:14 +08:00
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -31,9 +31,6 @@ inline void build_kernel(
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
// For scalar we shouldn't do the indexing things, just read at 0
|
||||
auto is_scalar = [](const array& x) { return x.size() == 1; };
|
||||
|
||||
NodeNamer namer;
|
||||
bool add_indices = false;
|
||||
int cnt = 0;
|
||||
@@ -226,8 +223,7 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true);
|
||||
|
||||
kernel_source_ = kernel.str();
|
||||
lib = d.get_library(kernel_lib_, kernel_source_);
|
||||
lib = d.get_library(kernel_lib_, kernel.str());
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
@@ -235,7 +231,7 @@ void Compiled::eval_gpu(
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
x.size() > 1) {
|
||||
!is_scalar(x)) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
@@ -256,7 +252,7 @@ void Compiled::eval_gpu(
|
||||
auto& x = inputs[i];
|
||||
|
||||
// Skip scalar inputs.
|
||||
if (x.size() <= 1) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -311,7 +307,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
if (!contiguous && x.size() > 1) {
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
|
Reference in New Issue
Block a user