Add Quantized Ops to the JIT (#1204)

* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-06-12 09:47:12 -07:00
committed by GitHub
parent df964132fb
commit dd7d8e5e29
13 changed files with 1778 additions and 1948 deletions

View File

@@ -1,12 +1,10 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/fft.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
// get_fft_kernel: fetches the compute pipeline state for an FFT kernel,
// assembling and registering the kernel source when no library named
// `kernel_name` is found on the device.
//
// NOTE(review): this span is rendered from a diff hunk without +/- markers, so
// lines from both sides of the change appear together (two `func_consts`
// parameter lines, two `kernel_source << metal::fft()` statements). It is not
// compilable as shown; confirm against the commit which lines survive.
@@ -494,47 +492,32 @@ MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts) {
const metal::MTLFCList& func_consts,
const std::string& template_def) {
// The library name is the kernel name itself.
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
// Presumably the removed variant (TODO confirm): picks a format-string
// template by substring match on the library name, then fills it in with
// fmt::format using named arguments.
std::string kernel_string;
if (lib_name.find("bluestein") != std::string::npos) {
kernel_string = bluestein_fft_kernel;
} else if (lib_name.find("rader") != std::string::npos) {
kernel_string = rader_fft_kernel;
} else if (lib_name.find("four_step") != std::string::npos) {
kernel_string = four_step_fft_kernel;
} else {
kernel_string = fft_kernel;
}
kernel_source << metal::fft();
// Only the "four_step" branch passes the extra `step` and `real` arguments.
if (lib_name.find("four_step") != std::string::npos) {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type,
"step"_a = step,
"real"_a = real);
} else {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type);
}
// Presumably the added variant (TODO confirm): the output of metal::fft()
// followed by the caller-supplied `template_def` text.
kernel_source << metal::fft() << template_def;
// Two-argument get_library receives the assembled source text.
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
// Returns the compute pipeline state for the quantized kernel `kernel_name`.
//
// The device is first asked for a library named after the kernel. When none
// is found, the source is assembled from the outputs of metal::utils(),
// metal::gemm() and metal::quantized(), followed by `template_def`, and
// handed to the device under the same name.
MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& template_def) {
  auto library = d.get_library(kernel_name);
  if (library != nullptr) {
    // A library with this name is already available; no source is needed.
    return d.get_kernel(kernel_name, library);
  }

  // Assemble the source: shared preambles first, then the caller's text.
  std::ostringstream source;
  source << metal::utils();
  source << metal::gemm();
  source << metal::quantized();
  source << template_def;

  library = d.get_library(kernel_name, source.str());
  return d.get_kernel(kernel_name, library);
}
} // namespace mlx::core