Custom logsumexp (#2028)

* initial custom logsumexp

* more tests

* comments + fix
This commit is contained in:
Awni Hannun
2025-03-31 07:36:55 -07:00
committed by GitHub
parent ec2854b13a
commit de5f38fd48
27 changed files with 590 additions and 255 deletions

View File

@@ -1,8 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::arange()
<< fmt::format(
arange_kernels,
kernel_name,
get_type_string(out.dtype()));
return kernel_source.str();
std::string kernel_source = metal::utils();
kernel_source += metal::arange();
kernel_source += get_template_definition(
kernel_name, "arange", get_type_string(out.dtype()));
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
return kernel_source.str();
std::string kernel_source = metal::utils();
auto in_type = get_type_string(out.dtype());
auto acc_type = get_type_string(precise ? float32 : out.dtype());
kernel_source += metal::softmax();
kernel_source += get_template_definition(
"block_" + lib_name, "softmax_single_row", in_type, acc_type);
kernel_source += get_template_definition(
"looped_" + lib_name, "softmax_looped", in_type, acc_type);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
// Look up the compute pipeline for a logsumexp kernel, supplying the JIT
// source for its library through a builder callback.
//
// `kernel_name` is the full kernel identifier; everything after its first '_'
// is used as the library name (if there is no '_', the whole name is used,
// because npos + 1 wraps around to 0). The generated source instantiates two
// templates for the output dtype: "logsumexp" under the name
// "block_<library>" and "logsumexp_looped" under "looped_<library>".
//
// NOTE(review): presumably `d.get_library` only invokes the builder when the
// library has not been built yet -- confirm against metal::Device.
MTL::ComputePipelineState* get_logsumexp_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  // Strip the leading "<variant>_" prefix to obtain the shared library name.
  const auto prefix_end = kernel_name.find('_');
  std::string lib_name = kernel_name.substr(prefix_end + 1);

  auto lib = d.get_library(lib_name, [&] {
    auto dtype_name = get_type_string(out.dtype());

    // Common utilities first, then the logsumexp kernel templates.
    std::string source = metal::utils();
    source += metal::logsumexp();

    // One instantiation per kernel variant, both living in the same library.
    source +=
        get_template_definition("block_" + lib_name, "logsumexp", dtype_name);
    source += get_template_definition(
        "looped_" + lib_name, "logsumexp_looped", dtype_name);
    return source;
  });

  return d.get_kernel(kernel_name, lib);
}