More jitting (#1132)

* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
This commit is contained in:
Awni Hannun
2024-05-23 16:23:44 -07:00
committed by GitHub
parent 9401507336
commit 0189ab6ab6
41 changed files with 2377 additions and 1846 deletions

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
@@ -28,30 +29,33 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in = arr_copy;
}
std::ostringstream kname;
if (in.strides()[axis_] == 1) {
kname << "contiguous_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
bool contiguous = in.strides()[axis_] == 1;
auto kernel = d.get_kernel(kname.str());
std::ostringstream kname;
kname << (contiguous ? "contig_" : "strided_");
kname << "scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out);
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
@@ -79,28 +83,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
kname << "strided_scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
break;
case Scan::Prod:
kname << "prod_";
break;
case Scan::Max:
kname << "max_";
break;
case Scan::Min:
kname << "min_";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);