mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops * remove unused imports * address comments * fix imports * second attempt to fix imports --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/binary.h"
|
||||
#include "mlx/backend/metal/jit/binary_two.h"
|
||||
#include "mlx/backend/metal/jit/copy.h"
|
||||
#include "mlx/backend/metal/jit/fft.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
@@ -494,47 +492,32 @@ MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const int tg_mem_size,
|
||||
const std::string& in_type,
|
||||
const std::string& out_type,
|
||||
int step,
|
||||
bool real,
|
||||
const metal::MTLFCList& func_consts) {
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string& template_def) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
std::string kernel_string;
|
||||
if (lib_name.find("bluestein") != std::string::npos) {
|
||||
kernel_string = bluestein_fft_kernel;
|
||||
} else if (lib_name.find("rader") != std::string::npos) {
|
||||
kernel_string = rader_fft_kernel;
|
||||
} else if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_string = four_step_fft_kernel;
|
||||
} else {
|
||||
kernel_string = fft_kernel;
|
||||
}
|
||||
kernel_source << metal::fft();
|
||||
if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type,
|
||||
"step"_a = step,
|
||||
"real"_a = real);
|
||||
} else {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type);
|
||||
}
|
||||
kernel_source << metal::fft() << template_def;
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
// Fetch the compute pipeline state for a quantized kernel.
//
// The library is looked up on the device under the kernel's own name. When
// that lookup yields nullptr, the source is assembled from the utils, gemm
// and quantized snippets followed by the caller-provided `template_def`, and
// the library is requested again with that source (presumably triggering a
// JIT build -- confirm against metal::Device::get_library).
//
//   d            - device used for the library and kernel lookups
//   kernel_name  - kernel to fetch; also used as the library name
//   template_def - text appended after the shared snippets
MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& template_def) {
  // One library per kernel: the library key is simply the kernel name.
  const std::string& library_key = kernel_name;

  auto library = d.get_library(library_key);
  if (library == nullptr) {
    // Shared snippets come first, then the caller's template definition.
    std::ostringstream source;
    source << metal::utils();
    source << metal::gemm();
    source << metal::quantized();
    source << template_def;
    library = d.get_library(library_key, source.str());
  }

  return d.get_kernel(kernel_name, library);
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user