// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <cassert>
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {
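
// Sorts each row of the input with one threadgroup. gpu_merge_sort
// dispatches here when the sorted axis fits in a single block of bn * tn
// elements, so no inter-block merging is needed.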
template <bool ARGSORT>
void single_block_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis,
    int bn,
    int tn) {
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  std::vector<size_t> nc_str = in.strides();
  nc_str.erase(nc_str.begin() + axis);

  std::vector<int> nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  int size_sorted_axis = in.shape(axis);
  int stride_sorted_axis = in.strides()[axis];
  int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());

  // Check if remaining strides are contiguous: in a packed row-major layout,
  // each stride equals the next stride times the next dimension's size
  bool contiguous_write = true;
  if (axis != in.ndim() - 1 && axis != 0) {
    for (int i = 0; i < nc_str.size() - 1; ++i) {
      size_t expected = nc_str[i + 1] * nc_shape[i + 1];
      contiguous_write &= (nc_str[i] == expected);
    }
  }

  // Prepare kernel name
  std::ostringstream kname;
  if (ARGSORT) {
    kname << "arg_";
  }
  kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
        << "_bn" << bn << "_tn" << tn;

  if (!contiguous_write) {
    kname << "_nc";
  }

  // Prepare command encoder
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  // Set inputs
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
  compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);

  if (contiguous_write) {
    compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
  } else {
    compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
    compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
    compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
  }

  MTL::Size group_dims = MTL::Size(bn, 1, 1);
  MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

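// Sorts rows whose sorted axis spans several blocks: first sort each block
// independently, then merge the sorted blocks over log2(n_blocks) passes,
// ping-ponging between two staging buffers.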
template <bool ARGSORT>
void multi_block_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis,
    int bn,
    int tn,
    int n_blocks) {
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  std::vector<size_t> nc_str = in.strides();
  nc_str.erase(nc_str.begin() + axis);

  std::vector<int> nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  if (nc_dim == 0) {
    nc_shape = {0};
    nc_str = {1};
  }

  int size_sorted_axis = in.shape(axis);
  int stride_sorted_axis = in.strides()[axis];

  // Make temporary copies
  array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
  array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});

  array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});
  array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});

  array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});

  // Do allocations
  dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
  dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
  dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
  dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
  block_partitions.set_data(
      allocator::malloc_or_wait(block_partitions.nbytes()));

  std::vector<array> copies = {
      dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};

  // Prepare command encoder
  auto& compute_encoder = d.get_command_encoder(s.index);

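  // Pass 1: each threadgroup sorts one block of bn * tn elements within a
  // row, leaving every row as a sequence of sorted blocks in dev_vals_0 and
  // dev_idxs_0.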
  // Do blockwise sort
  {
    std::ostringstream kname;
    kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
          << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;

    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(dev_vals_0, 1);
    compute_encoder.set_output_array(dev_idxs_0, 2);
    compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
    compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
    compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
    compute_encoder->setBytes(
        nc_shape.data(), nc_shape.size() * sizeof(int), 6);
    compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);

    MTL::Size group_dims = MTL::Size(bn, 1, 1);
    MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  }

  // Do merges
  bool ping = false;
  array dev_vals_in = dev_vals_0;
  array dev_idxs_in = dev_idxs_0;
  array dev_vals_out = dev_vals_1;
  array dev_idxs_out = dev_idxs_1;

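  // Each pass merges pairs of sorted tiles; merge_tiles counts the blocks
  // spanned by the tiles produced in the current pass.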
  for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
    dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
    dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
    dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
    dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
    ping = !ping;

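    // Compute the partition points that split each output block's work
    // across the two input tiles being merged, using one thread per
    // partition boundary (n_blocks + 1 boundaries per row).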
    // Do partition
    {
      std::ostringstream kname;
      kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
            << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      compute_encoder.set_output_array(block_partitions, 0);
      compute_encoder.set_input_array(dev_vals_in, 1);
      compute_encoder.set_input_array(dev_idxs_in, 2);
      compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
      compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);

      MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
      MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    }

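    // Merge along the precomputed partitions; each threadgroup produces one
    // sorted output block.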
    // Do merge
    {
      std::ostringstream kname;
      kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
            << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      compute_encoder.set_input_array(block_partitions, 0);
      compute_encoder.set_input_array(dev_vals_in, 1);
      compute_encoder.set_input_array(dev_idxs_in, 2);
      compute_encoder.set_output_array(dev_vals_out, 3);
      compute_encoder.set_output_array(dev_idxs_out, 4);
      compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
      compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
      compute_encoder->setBytes(&n_blocks, sizeof(int), 7);

      MTL::Size group_dims = MTL::Size(bn, 1, 1);
      MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    }
  }

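  // The staging buffers are row-contiguous with the sorted axis last. If
  // that already matches the output layout, copy directly; otherwise view
  // the buffer through strides that move the sorted axis back into place.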
  // Copy outputs with appropriate strides
  array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;

  if (axis == in.ndim() - 1) {
    copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
  } else {
    std::vector<int> strided_out_shape = strided_out_arr.shape();
    std::vector<size_t> strided_out_str = strided_out_arr.strides();

    int out_axis_shape = strided_out_shape[axis];
    int out_axis_str = strided_out_str[axis];

    strided_out_shape.erase(strided_out_shape.begin() + axis);
    strided_out_str.erase(strided_out_str.begin() + axis);

    strided_out_shape.push_back(out_axis_shape);
    strided_out_str.push_back(out_axis_str);

    array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
    strided_out_slice.copy_shared_buffer(
        strided_out_arr,
        strided_out_str,
        strided_out_arr.flags(),
        strided_out_arr.size(),
        0);

    copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
  }

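  // Capturing the temporaries by value keeps them alive until the command
  // buffer completes; clearing the vector then releases them.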
  // Clear copies
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

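// Picks the block parameters (bn, tn) for the sorted axis length and
// dispatches to the single-block or multi-block sort path.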
template <bool ARGSORT>
void gpu_merge_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis_) {
  // Get size info
  int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
  int size_sorted_axis = in.shape(axis);

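  // Each thread handles tn elements, so a threadgroup of bn threads covers
  // bn * tn elements; pick the smallest bn in {128, 256, 512} that lets a
  // single block cover the sorted axis when possible.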
  // Get kernel size
  int tn = 8;
  int bn = 128;
  int potential_bn = (size_sorted_axis + tn - 1) / tn;

  if (potential_bn > 256) {
    bn = 512;
  } else if (potential_bn > 128) {
    bn = 256;
  } else {
    bn = 128;
  }

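  // For value types wider than 4 bytes, bn = 512 presumably exceeds the
  // available threadgroup memory, so back off to 256.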
  if (bn == 512 && size_of(in.dtype()) > 4) {
    bn = 256;
  }

  int n_per_block = bn * tn;
  int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;

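  // Use the multi-block path only when one block cannot cover the whole axis.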
  if (n_blocks > 1) {
    return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
  } else {
    return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
  }
}

} // namespace

void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<true>(s, d, in, out, axis_);
}

void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<false>(s, d, in, out, axis_);
}

void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
  // We direct arg partition to sort for now
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<true>(s, d, in, out, axis_);
}

void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
  // We direct partition to sort for now
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<false>(s, d, in, out, axis_);
}

} // namespace mlx::core