Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Working 64-bit scans (#1506)
This commit is contained in:
Committed by: GitHub
Parent: 32972a5924
Commit: c9b41d460f
@@ -4,7 +4,6 @@
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
@@ -224,18 +223,26 @@ MTL::ComputePipelineState* get_scan_kernel(
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string op_name = "Cum" + reduce_type;
|
||||
op_name[3] = toupper(op_name[3]);
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
std::string op = "Cum" + reduce_type + "<" + out_type + ">";
|
||||
op[3] = toupper(op[3]);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::scan()
|
||||
<< fmt::format(
|
||||
scan_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name,
|
||||
inclusive,
|
||||
reverse);
|
||||
kernel_source << metal::utils() << metal::scan();
|
||||
const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
|
||||
{"contig_", "contiguous_scan"},
|
||||
{"strided_", "strided_scan"},
|
||||
}};
|
||||
for (auto& [prefix, kernel] : scan_kernels) {
|
||||
kernel_source << get_template_definition(
|
||||
prefix + lib_name,
|
||||
kernel,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op,
|
||||
in.itemsize() <= 4 ? 4 : 2,
|
||||
inclusive,
|
||||
reverse);
|
||||
}
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
||||
Reference in New Issue
Block a user