Refactor JIT for unary/binary/ternary ops (#1206)

* refactor unary/binary/ternary ops

* get_primitive_string util

---------
This commit is contained in:
Alex Barron
2024-06-12 14:22:12 -07:00
committed by GitHub
parent de2b9e7d0a
commit 934683088e
16 changed files with 379 additions and 935 deletions

View File

@@ -1,9 +1,8 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
@@ -12,8 +11,6 @@
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -46,38 +43,76 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< fmt::format(
unary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
<< u_def << g_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
void add_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::map<std::string, std::string> kernel_types = {
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"},
{"g5", "binary_g_nd"},
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
kernel_source << template_def;
}
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
<< fmt::format(
binary_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -86,20 +121,16 @@ MTL::ComputePipelineState* get_binary_kernel(
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two()
<< fmt::format(
binary_two_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -108,17 +139,34 @@ MTL::ComputePipelineState* get_binary_two_kernel(
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
Dtype type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
<< fmt::format(
ternary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
const std::map<std::string, std::string> kernel_types = {
{"v", "ternary_v"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"},
{"g5", "ternary_g_nd"},
};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
kernel_source << template_def;
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);