More fixes for arrays with large sizes (#1405)

* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
This commit is contained in:
Awni Hannun
2024-09-17 12:46:31 -07:00
committed by GitHub
parent c6739ba7f3
commit 4f46e9c997
26 changed files with 325 additions and 611 deletions

View File

@@ -1,9 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/scan.h"
@@ -67,7 +65,7 @@ void add_binary_kernels(
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::map<std::string, std::string> kernel_types = {
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
@@ -78,29 +76,16 @@ void add_binary_kernels(
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"},
{"g5", "binary_g_nd"},
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
}};
for (auto& [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
}
}
@@ -149,27 +134,19 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
const std::map<std::string, std::string> kernel_types = {
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"},
{"g5", "ternary_g_nd"},
};
}};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto [name, func] : kernel_types) {
for (auto& [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
}
lib = d.get_library(lib_name, kernel_source.str());
@@ -186,12 +163,27 @@ MTL::ComputePipelineState* get_copy_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::copy()
<< fmt::format(
copy_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()));
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source
<< metal::utils() << metal::copy()
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"gg_" + lib_name, "copy_gg", in_type, out_type);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -296,11 +288,11 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = {
{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
{{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}}};
for (auto& [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,