mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix strided sort bug (#1236)
* Use output strides in sort kernel * fix zero strides bug
This commit is contained in:
@@ -8,7 +8,6 @@
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/sort.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -251,14 +250,29 @@ MTL::ComputePipelineState* get_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
block_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
for (bool is_argsort : {true, false}) {
|
||||
std::string bool_string = is_argsort ? "true" : "false";
|
||||
std::string func_string = is_argsort ? "carg_" : "c_";
|
||||
kernel_source << get_template_definition(
|
||||
func_string + lib_name,
|
||||
"block_sort",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
kernel_source << get_template_definition(
|
||||
"n" + func_string + lib_name,
|
||||
"block_sort_nc",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -275,14 +289,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
multiblock_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
std::vector<std::pair<std::string, std::string>> kernel_types = {
|
||||
{"sort_", "mb_block_sort"},
|
||||
{"partition_", "mb_block_partition"},
|
||||
{"merge_", "mb_block_merge"}};
|
||||
for (auto [name, func] : kernel_types) {
|
||||
kernel_source << get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
"true",
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
||||
Reference in New Issue
Block a user