More jitting (#1132)

* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
This commit is contained in:
Awni Hannun
2024-05-23 16:23:44 -07:00
committed by GitHub
parent 9401507336
commit 0189ab6ab6
41 changed files with 2377 additions and 1846 deletions

View File

@@ -4,6 +4,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
@@ -11,7 +12,6 @@ namespace mlx::core {
namespace {
template <bool ARGSORT>
void single_block_sort(
const Stream& s,
metal::Device& d,
@@ -19,7 +19,8 @@ void single_block_sort(
array& out,
int axis,
int bn,
int tn) {
int tn,
bool argsort) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@@ -46,19 +47,17 @@ void single_block_sort(
// Prepare kernel name
std::ostringstream kname;
if (ARGSORT) {
kname << "arg_";
kname << (contiguous_write ? "c" : "nc");
if (argsort) {
kname << "arg";
}
kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;
if (!contiguous_write) {
kname << "_nc";
}
kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;
auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn);
// Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Set inputs
@@ -81,7 +80,6 @@ void single_block_sort(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
template <bool ARGSORT>
void multi_block_sort(
const Stream& s,
metal::Device& d,
@@ -90,7 +88,8 @@ void multi_block_sort(
int axis,
int bn,
int tn,
int n_blocks) {
int n_blocks,
bool argsort) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@@ -136,10 +135,10 @@ void multi_block_sort(
// Do blockwise sort
{
std::ostringstream kname;
kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_"
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
@@ -175,10 +174,11 @@ void multi_block_sort(
// Do partition
{
std::ostringstream kname;
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_output_array(block_partitions, 0);
@@ -196,10 +196,11 @@ void multi_block_sort(
// Do merge
{
std::ostringstream kname;
kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(block_partitions, 0);
@@ -219,7 +220,7 @@ void multi_block_sort(
}
// Copy outputs with appropriate strides
array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
if (axis == strided_out_arr.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
@@ -252,13 +253,13 @@ void multi_block_sort(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
template <bool ARGSORT>
void gpu_merge_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis_) {
int axis_,
bool argsort) {
// Get size info
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
int size_sorted_axis = in.shape(axis);
@@ -284,9 +285,9 @@ void gpu_merge_sort(
int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;
if (n_blocks > 1) {
return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks, argsort);
} else {
return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
return single_block_sort(s, d, in, out, axis, bn, tn, argsort);
}
}
@@ -301,7 +302,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, true);
}
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -313,7 +314,7 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, false);
}
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -326,7 +327,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, true);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -339,7 +340,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
gpu_merge_sort(s, d, in, out, axis_, false);
}
} // namespace mlx::core