Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
More jitting (#1132)

* docs + circle min size build
* jit scan, arange, softmax
* add sort
* jit reductions
* remove print
* fix deps
* clean includes / nits
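The message is terse, so here is a rough, hypothetical sketch of the pattern behind "jitting" a kernel: instead of looking up a precompiled pipeline, the backend generates the kernel source for a given name on first use, compiles it, and caches the result. The Device, Kernel, get_kernel_jit, and make_source names below are illustrative assumptions, not MLX's actual API; get_scan_kernel in the diff below is the real entry point this commit switches to.

// Hypothetical illustration of JIT-style kernel lookup (not MLX's API):
// generate and compile a kernel's source only on the first request for a
// given name, then serve it from a cache afterwards.
#include <functional>
#include <string>
#include <unordered_map>

struct Kernel {};  // stand-in for a compiled compute pipeline

struct Device {
  std::unordered_map<std::string, Kernel> kernel_cache;

  // `make_source` stands in for instantiating a kernel template, e.g. for a
  // particular scan variant and dtype pair.
  Kernel& get_kernel_jit(
      const std::string& name,
      const std::function<std::string()>& make_source) {
    auto it = kernel_cache.find(name);
    if (it == kernel_cache.end()) {
      std::string source = make_source();
      (void)source;  // a real backend would compile `source` here
      it = kernel_cache.emplace(name, Kernel{}).first;
    }
    return it->second;
  }
};

int main() {
  Device d;
  auto src = [] { return std::string("/* generated Metal source */"); };
  // First call "compiles"; the second hits the cache.
  d.get_kernel_jit("contig_scan_inclusive_sum_float32_float32", src);
  d.get_kernel_jit("contig_scan_inclusive_sum_float32_float32", src);
  return 0;
}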
@@ -5,6 +5,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
@@ -28,30 +29,33 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     in = arr_copy;
   }

-  std::ostringstream kname;
-  if (in.strides()[axis_] == 1) {
-    kname << "contiguous_scan_";
-    if (reverse_) {
-      kname << "reverse_";
-    }
-    kname << ((inclusive_) ? "inclusive_" : "exclusive_");
-    switch (reduce_type_) {
-      case Scan::Sum:
-        kname << "sum_";
-        break;
-      case Scan::Prod:
-        kname << "prod_";
-        break;
-      case Scan::Max:
-        kname << "max_";
-        break;
-      case Scan::Min:
-        kname << "min_";
-        break;
-    }
-    kname << type_to_name(in) << "_" << type_to_name(out);
-
-    auto kernel = d.get_kernel(kname.str());
+  bool contiguous = in.strides()[axis_] == 1;
+
+  std::ostringstream kname;
+  kname << (contiguous ? "contig_" : "strided_");
+  kname << "scan_";
+  if (reverse_) {
+    kname << "reverse_";
+  }
+  kname << ((inclusive_) ? "inclusive_" : "exclusive_");
+  switch (reduce_type_) {
+    case Scan::Sum:
+      kname << "sum_";
+      break;
+    case Scan::Prod:
+      kname << "prod_";
+      break;
+    case Scan::Max:
+      kname << "max_";
+      break;
+    case Scan::Min:
+      kname << "min_";
+      break;
+  }
+  kname << type_to_name(in) << "_" << type_to_name(out);
+  auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out);
+
+  if (contiguous) {
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
     compute_encoder.set_input_array(in, 0);
@@ -79,28 +83,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
-    kname << "strided_scan_";
-    if (reverse_) {
-      kname << "reverse_";
-    }
-    kname << ((inclusive_) ? "inclusive_" : "exclusive_");
-    switch (reduce_type_) {
-      case Scan::Sum:
-        kname << "sum_";
-        break;
-      case Scan::Prod:
-        kname << "prod_";
-        break;
-      case Scan::Max:
-        kname << "max_";
-        break;
-      case Scan::Min:
-        kname << "min_";
-        break;
-    }
-    kname << type_to_name(in) << "_" << type_to_name(out);
-
-    auto kernel = d.get_kernel(kname.str());
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
     compute_encoder.set_input_array(in, 0);
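As a reference for the naming scheme above, the kernel-name assembly from the new code can be reproduced in isolation. The sketch below is a minimal standalone program that mirrors the string building in the diff; the "float32" suffixes are only an assumed example of what type_to_name() returns for the input and output dtypes.

// Minimal standalone sketch of the kernel-name construction (not MLX code).
#include <iostream>
#include <sstream>
#include <string>

enum class ReduceType { Sum, Prod, Max, Min };

std::string scan_kernel_name(
    bool contiguous,
    bool reverse,
    bool inclusive,
    ReduceType reduce_type,
    const std::string& in_type,    // stand-in for type_to_name(in)
    const std::string& out_type) { // stand-in for type_to_name(out)
  std::ostringstream kname;
  kname << (contiguous ? "contig_" : "strided_");
  kname << "scan_";
  if (reverse) {
    kname << "reverse_";
  }
  kname << (inclusive ? "inclusive_" : "exclusive_");
  switch (reduce_type) {
    case ReduceType::Sum:
      kname << "sum_";
      break;
    case ReduceType::Prod:
      kname << "prod_";
      break;
    case ReduceType::Max:
      kname << "max_";
      break;
    case ReduceType::Min:
      kname << "min_";
      break;
  }
  kname << in_type << "_" << out_type;
  return kname.str();
}

int main() {
  // Prints "contig_scan_inclusive_sum_float32_float32".
  std::cout << scan_kernel_name(
                   true, false, true, ReduceType::Sum, "float32", "float32")
            << "\n";
  return 0;
}

With these arguments the sketch prints contig_scan_inclusive_sum_float32_float32; in the diff, such a name is passed to get_scan_kernel together with the reverse/inclusive flags and the input and output arrays.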