Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
More jitting (#1132)
* docs + circle min size build
* jit scan, arange, softmax
* add sort
* jit reductions
* remove print
* fix deps
* clean includes / nits
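The unifying theme of this commit is moving more Metal kernels to runtime (JIT) compilation: host code composes a specialization-specific kernel name, and a getter compiles the matching source on first use and serves it from a cache afterwards. The sketch below is a minimal, self-contained illustration of that name-keyed caching pattern with placeholder types; the example name only mimics the naming scheme that appears in the sort diff below, and none of this is MLX's actual device or kernel machinery.

// Toy sketch of name-keyed JIT kernel caching (illustration only, not MLX code).
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct Pipeline {
  std::string source;  // stand-in for a compiled MTL::ComputePipelineState
};

class KernelCache {
 public:
  // Return the cached pipeline for `name`, building it on first request.
  Pipeline& get(const std::string& name,
                const std::function<std::string()>& make_source) {
    auto it = cache_.find(name);
    if (it == cache_.end()) {
      std::cout << "jit compiling " << name << "\n";
      it = cache_.emplace(name, Pipeline{make_source()}).first;
    }
    return it->second;
  }

 private:
  std::unordered_map<std::string, Pipeline> cache_;
};

int main() {
  KernelCache cache;
  auto src = [] { return std::string("/* dtype- and tile-specialized source */"); };
  cache.get("carg_block_sort_float32_uint32_bn512_tn8", src);  // compiles once
  cache.get("carg_block_sort_float32_uint32_bn512_tn8", src);  // cache hit
  return 0;
}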
@@ -4,6 +4,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"

@@ -11,7 +12,6 @@ namespace mlx::core {

 namespace {

-template <bool ARGSORT>
 void single_block_sort(
     const Stream& s,
     metal::Device& d,
@@ -19,7 +19,8 @@ void single_block_sort(
     array& out,
     int axis,
     int bn,
-    int tn) {
+    int tn,
+    bool argsort) {
   // Prepare shapes
   int n_rows = in.size() / in.shape(axis);

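A pattern repeated throughout this file: the compile-time template <bool ARGSORT> specialization becomes a plain bool argsort argument. With every kernel precompiled, the host-side template stamped out one copy of each sort routine per flag value; once the kernel is selected by name and JIT-compiled, a runtime flag is all the host needs. A minimal standalone contrast of the two styles (generic example, not MLX code):

#include <iostream>
#include <string>

// Old style: the flag is a template parameter, so the compiler stamps out
// one host-side function per value.
template <bool ARGSORT>
std::string kernel_prefix_templated() {
  return ARGSORT ? "arg_" : "";
}

// New style: the flag is an ordinary runtime argument; one function handles
// both cases, and the specialization lives only in the kernel name it builds.
std::string kernel_prefix(bool argsort) {
  return argsort ? "arg" : "";
}

int main() {
  std::cout << kernel_prefix_templated<true>() << "\n";  // "arg_"
  std::cout << kernel_prefix(true) << "\n";              // "arg"
}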
@@ -46,19 +47,17 @@ void single_block_sort(

   // Prepare kernel name
   std::ostringstream kname;
-  if (ARGSORT) {
-    kname << "arg_";
+  kname << (contiguous_write ? "c" : "nc");
+  if (argsort) {
+    kname << "arg";
   }
-  kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
-        << "_bn" << bn << "_tn" << tn;
-
-  if (!contiguous_write) {
-    kname << "_nc";
-  }
+  kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out)
+        << "_bn" << bn << "_tn" << tn;
+
+  auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn);

   // Prepare command encoder
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

   // Set inputs
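The name assembled above follows the new scheme <c|nc>[arg]_block_sort_<in>_<out>_bn<bn>_tn<tn>, replacing the old [arg_]block_merge_sort_..._bn.._tn..[_nc], and get_sort_kernel is handed that name to JIT and cache the pipeline. A standalone sketch of the new composition; the float32/uint32 type strings and the bn/tn values are illustrative stand-ins, not taken from this diff:

#include <iostream>
#include <sstream>
#include <string>

// Assemble a sort kernel name in the style used above.
std::string block_sort_name(
    bool contiguous_write,
    bool argsort,
    const std::string& in_type,
    const std::string& out_type,
    int bn,
    int tn) {
  std::ostringstream kname;
  kname << (contiguous_write ? "c" : "nc");
  if (argsort) {
    kname << "arg";
  }
  kname << "_block_sort_" << in_type << "_" << out_type << "_bn" << bn
        << "_tn" << tn;
  return kname.str();
}

int main() {
  std::cout << block_sort_name(true, true, "float32", "uint32", 512, 8) << "\n";
  // carg_block_sort_float32_uint32_bn512_tn8
  std::cout << block_sort_name(false, false, "float32", "float32", 512, 8) << "\n";
  // nc_block_sort_float32_float32_bn512_tn8
}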
@@ -81,7 +80,6 @@ void single_block_sort(
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

-template <bool ARGSORT>
 void multi_block_sort(
     const Stream& s,
     metal::Device& d,
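multi_block_sort is a merge sort across threadgroup-sized blocks: each block is sorted on its own, then sorted runs are merged in passes, with a partition kernel computing split points so the merge itself can run in parallel. As a plain-CPU analogy of that structure (standard library only, sequential, and without the partition step), consider:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// CPU analogy of a block merge sort: sort fixed-size blocks, then merge
// neighboring sorted runs in passes of doubling width.
void block_merge_sort(std::vector<int>& v, std::size_t block) {
  // 1) Blockwise sort.
  for (std::size_t i = 0; i < v.size(); i += block) {
    std::sort(v.begin() + i, v.begin() + std::min(i + block, v.size()));
  }
  // 2) Merge passes over pairs of sorted runs.
  for (std::size_t width = block; width < v.size(); width *= 2) {
    for (std::size_t i = 0; i + width < v.size(); i += 2 * width) {
      auto mid = v.begin() + i + width;
      auto last = v.begin() + std::min(i + 2 * width, v.size());
      std::inplace_merge(v.begin() + i, mid, last);
    }
  }
}

int main() {
  std::vector<int> v = {9, 3, 7, 1, 8, 2, 6, 4, 5, 0};
  block_merge_sort(v, 4);
  for (int x : v) std::cout << x << ' ';
  std::cout << '\n';  // 0 1 2 3 4 5 6 7 8 9
}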
@@ -90,7 +88,8 @@ void multi_block_sort(
     int axis,
     int bn,
     int tn,
-    int n_blocks) {
+    int n_blocks,
+    bool argsort) {
   // Prepare shapes
   int n_rows = in.size() / in.shape(axis);

@@ -136,10 +135,10 @@ void multi_block_sort(
   // Do blockwise sort
   {
     std::ostringstream kname;
-    kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
+    kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_"
           << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
-
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel =
+        get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
     compute_encoder->setComputePipelineState(kernel);

     compute_encoder.set_input_array(in, 0);
@@ -175,10 +174,11 @@ void multi_block_sort(
   // Do partition
   {
     std::ostringstream kname;
-    kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
+    kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_"
           << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

-    auto kernel = d.get_kernel(kname.str());
+    auto kernel =
+        get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
     compute_encoder->setComputePipelineState(kernel);

     compute_encoder.set_output_array(block_partitions, 0);
@@ -196,10 +196,11 @@ void multi_block_sort(
   // Do merge
   {
     std::ostringstream kname;
-    kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
+    kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_"
           << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

-    auto kernel = d.get_kernel(kname.str());
+    auto kernel =
+        get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
     compute_encoder->setComputePipelineState(kernel);

     compute_encoder.set_input_array(block_partitions, 0);
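All three stages of the multi-block pipeline above (blockwise sort, partition, merge) now go through the single JIT getter get_mb_sort_kernel, with the stage encoded as a prefix on a shared name suffix: sort_mbsort_..., partition_mbsort_... and merge_mbsort_... replace the old mb_block_sort_/mb_block_partition_/mb_block_merge_ names. A standalone sketch of that naming (the type strings and bn/tn values are illustrative, not taken from this diff):

#include <iostream>
#include <sstream>
#include <string>

// Build the stage-specific kernel names used by the multi-block sort,
// following the prefixes introduced above ("sort", "partition", "merge").
std::string mbsort_name(
    const std::string& stage,
    const std::string& val_type,
    const std::string& idx_type,
    int bn,
    int tn) {
  std::ostringstream kname;
  kname << stage << "_mbsort_" << val_type << "_" << idx_type << "_bn" << bn
        << "_tn" << tn;
  return kname.str();
}

int main() {
  for (const char* stage : {"sort", "partition", "merge"}) {
    std::cout << mbsort_name(stage, "float32", "uint32", 512, 8) << "\n";
  }
  // sort_mbsort_float32_uint32_bn512_tn8
  // partition_mbsort_float32_uint32_bn512_tn8
  // merge_mbsort_float32_uint32_bn512_tn8
}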
@@ -219,7 +220,7 @@ void multi_block_sort(
   }

   // Copy outputs with appropriate strides
-  array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;
+  array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;

   if (axis == strided_out_arr.ndim() - 1) {
     copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
@@ -252,13 +253,13 @@ void multi_block_sort(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 }

-template <bool ARGSORT>
 void gpu_merge_sort(
     const Stream& s,
     metal::Device& d,
     const array& in,
     array& out,
-    int axis_) {
+    int axis_,
+    bool argsort) {
   // Get size info
   int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
   int size_sorted_axis = in.shape(axis);
@@ -284,9 +285,9 @@ void gpu_merge_sort(
   int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;

   if (n_blocks > 1) {
-    return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
+    return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks, argsort);
   } else {
-    return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
+    return single_block_sort(s, d, in, out, axis, bn, tn, argsort);
   }
 }

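The choice between the two paths comes down to a ceiling division: one threadgroup handles n_per_block = bn * tn elements, and n_blocks = ceil(size_sorted_axis / n_per_block) decides whether the single-block kernel suffices. A standalone sketch of that arithmetic and the resulting dispatch (the bn, tn and axis sizes are illustrative values, not from this diff):

#include <iostream>

// Mirror of the dispatch logic: one threadgroup handles bn * tn elements;
// anything larger falls back to the multi-block path.
const char* choose_sort_path(int size_sorted_axis, int bn, int tn) {
  int n_per_block = bn * tn;
  int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;  // ceil
  return n_blocks > 1 ? "multi_block_sort" : "single_block_sort";
}

int main() {
  std::cout << choose_sort_path(1000, 512, 8) << "\n";    // single_block_sort
  std::cout << choose_sort_path(100000, 512, 8) << "\n";  // multi_block_sort
}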
@@ -301,7 +302,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
   auto& in = inputs[0];

-  gpu_merge_sort<true>(s, d, in, out, axis_);
+  gpu_merge_sort(s, d, in, out, axis_, true);
 }

 void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -313,7 +314,7 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
   auto& in = inputs[0];

-  gpu_merge_sort<false>(s, d, in, out, axis_);
+  gpu_merge_sort(s, d, in, out, axis_, false);
 }

 void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -326,7 +327,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
   auto& in = inputs[0];

-  gpu_merge_sort<true>(s, d, in, out, axis_);
+  gpu_merge_sort(s, d, in, out, axis_, true);
 }

 void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -339,7 +340,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
   auto& in = inputs[0];

-  gpu_merge_sort<false>(s, d, in, out, axis_);
+  gpu_merge_sort(s, d, in, out, axis_, false);
 }

 } // namespace mlx::core
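Worth noting from the last two hunks: on this backend ArgPartition and Partition take exactly the same path as ArgSort and Sort, so a full merge sort is run and the partitioned result is simply read off the fully sorted output. A selection algorithm such as the standard library's std::nth_element does strictly less work; for contrast (generic illustration, unrelated to MLX code):

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> v = {9, 3, 7, 1, 8, 2, 6, 4, 5, 0};
  int kth = 4;
  // Selection: only guarantees v[kth] is in its sorted position, with
  // smaller-or-equal elements before it and larger-or-equal elements after it.
  std::nth_element(v.begin(), v.begin() + kth, v.end());
  std::cout << "kth value: " << v[kth] << "\n";  // 4
  for (int x : v) std::cout << x << ' ';
  std::cout << '\n';
}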