mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39 * work per thread for compiled kernels * fixe for large arrays * fix
This commit is contained in:
parent
0eef4febfd
commit
881615b072
@ -84,8 +84,7 @@ void binary_op_gpu_inplace(
|
|||||||
|
|
||||||
bool use_2d = out.data_size() > UINT32_MAX;
|
bool use_2d = out.data_size() > UINT32_MAX;
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
int work_per_thread =
|
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
|
||||||
(bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
|
|
||||||
std::string kernel_name =
|
std::string kernel_name =
|
||||||
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
|
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
constexpr int WORK_PER_THREAD = 4;
|
||||||
|
|
||||||
inline void build_kernel(
|
inline void build_kernel(
|
||||||
std::ostream& os,
|
std::ostream& os,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
@ -23,7 +25,8 @@ inline void build_kernel(
|
|||||||
bool contiguous,
|
bool contiguous,
|
||||||
int ndim,
|
int ndim,
|
||||||
bool dynamic_dims,
|
bool dynamic_dims,
|
||||||
bool use_big_index = false) {
|
bool use_big_index = false,
|
||||||
|
int work_per_thread = 1) {
|
||||||
// All outputs should have the exact same shape and will be row contiguous
|
// All outputs should have the exact same shape and will be row contiguous
|
||||||
auto output_shape = outputs[0].shape();
|
auto output_shape = outputs[0].shape();
|
||||||
auto output_strides = outputs[0].strides();
|
auto output_strides = outputs[0].strides();
|
||||||
@ -38,8 +41,8 @@ inline void build_kernel(
|
|||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
|
||||||
// Start the kernel
|
// Start the kernel
|
||||||
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
|
os << "[[host_name(\"" << kernel_name << "\")]]\n"
|
||||||
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
|
<< "[[kernel]] void " << kernel_name << "(\n";
|
||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
for (auto& x : inputs) {
|
for (auto& x : inputs) {
|
||||||
@ -53,11 +56,11 @@ inline void build_kernel(
|
|||||||
// Scalars and contiguous need no strides
|
// Scalars and contiguous need no strides
|
||||||
if (is_scalar(x) || contiguous) {
|
if (is_scalar(x) || contiguous) {
|
||||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||||
} else {
|
} else {
|
||||||
add_indices = true;
|
add_indices = true;
|
||||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,58 +72,37 @@ inline void build_kernel(
|
|||||||
// Add the output arguments
|
// Add the output arguments
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
os << " device " << get_type_string(x.dtype()) << "* "
|
os << " device " << get_type_string(x.dtype()) << "* "
|
||||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
|
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
|
||||||
}
|
}
|
||||||
// Add output strides and shape to extract the indices.
|
// Add output strides and shape to extract the indices.
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
||||||
<< ")]]," << std::endl
|
<< ")]],\n"
|
||||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
|
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
|
||||||
<< std::endl;
|
|
||||||
}
|
}
|
||||||
if (dynamic_dims) {
|
if (dynamic_dims) {
|
||||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
|
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
|
||||||
<< std::endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The thread index in the whole grid
|
// The thread index in the whole grid
|
||||||
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
|
os << " uint3 pos [[thread_position_in_grid]],\n"
|
||||||
<< " uint3 grid [[threads_per_grid]]) {" << std::endl;
|
<< " uint3 grid [[threads_per_grid]]) {\n";
|
||||||
|
|
||||||
if (use_big_index) {
|
if (use_big_index) {
|
||||||
// This is only used for contiguous kernels which don't have
|
// This is only used for contiguous kernels which don't have
|
||||||
// a third grid dimension
|
// a third grid dimension
|
||||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);";
|
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||||
|
} else if (work_per_thread > 1) {
|
||||||
|
os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n"
|
||||||
|
<< " int xshape = output_shape["
|
||||||
|
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
|
||||||
|
<< " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||||
} else {
|
} else {
|
||||||
os << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
|
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||||
}
|
|
||||||
os << std::endl;
|
|
||||||
|
|
||||||
// Extract the indices per axis to individual uints if we have arrays that
|
|
||||||
// are broadcasted or transposed
|
|
||||||
if (add_indices) {
|
|
||||||
if (!dynamic_dims) {
|
|
||||||
if (ndim == 1) {
|
|
||||||
os << " uint index_0 = pos.x;" << std::endl;
|
|
||||||
} else if (ndim == 2) {
|
|
||||||
os << " uint index_0 = pos.y;" << std::endl
|
|
||||||
<< " uint index_1 = pos.x;" << std::endl;
|
|
||||||
} else if (ndim == 3) {
|
|
||||||
os << " uint index_0 = pos.z;" << std::endl
|
|
||||||
<< " uint index_1 = pos.y;" << std::endl
|
|
||||||
<< " uint index_2 = pos.x;" << std::endl;
|
|
||||||
} else {
|
|
||||||
for (int i = 0; i < ndim - 2; i++) {
|
|
||||||
os << " uint index_" << i << " = (index / uint(output_strides[" << i
|
|
||||||
<< "])) % output_shape[" << i << "];" << std::endl;
|
|
||||||
}
|
|
||||||
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
|
|
||||||
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the inputs in tmps
|
// Read constant / contiguous inputs in tmps
|
||||||
int nc_in_count = 0;
|
std::vector<array> nc_inputs;
|
||||||
for (int i = 0; i < inputs.size(); ++i) {
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
auto& x = inputs[i];
|
auto& x = inputs[i];
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
@ -130,56 +112,117 @@ inline void build_kernel(
|
|||||||
os << " auto tmp_" << xname << " = static_cast<"
|
os << " auto tmp_" << xname << " = static_cast<"
|
||||||
<< get_type_string(x.dtype()) << ">(";
|
<< get_type_string(x.dtype()) << ">(";
|
||||||
print_constant(os, x);
|
print_constant(os, x);
|
||||||
os << ");" << std::endl;
|
os << ");\n";
|
||||||
} else if (is_scalar(x)) {
|
} else if (is_scalar(x)) {
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||||
<< xname << "[0];" << std::endl;
|
<< xname << "[0];\n";
|
||||||
} else if (contiguous) {
|
} else if (contiguous) {
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||||
<< xname << "[index];" << std::endl;
|
<< xname << "[index];\n";
|
||||||
} else if (!dynamic_dims) {
|
|
||||||
int offset = nc_in_count * ndim;
|
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
|
||||||
<< xname << "[";
|
|
||||||
os << "index_0 * " << "in_strides[" << offset << "]";
|
|
||||||
for (int i = 1; i < ndim; i++) {
|
|
||||||
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
|
|
||||||
}
|
|
||||||
os << "];" << std::endl;
|
|
||||||
nc_in_count++;
|
|
||||||
} else {
|
} else {
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
nc_inputs.push_back(x);
|
||||||
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
|
|
||||||
<< nc_in_count * ndim << ", ndim)];" << std::endl;
|
|
||||||
nc_in_count++;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize the indices for non-contiguous inputs
|
||||||
|
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||||
|
auto& xname = namer.get_name(nc_inputs[i]);
|
||||||
|
if (ndim == 1) {
|
||||||
|
int offset = i * ndim;
|
||||||
|
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
|
||||||
|
<< "in_strides[" << offset << "]);\n";
|
||||||
|
} else if (ndim == 2) {
|
||||||
|
int offset = i * ndim;
|
||||||
|
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
|
||||||
|
<< "in_strides + " << offset << ");\n";
|
||||||
|
} else if (ndim == 3) {
|
||||||
|
int offset = i * ndim;
|
||||||
|
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
|
||||||
|
<< "in_strides + " << offset << ");\n";
|
||||||
|
} else if (!dynamic_dims) {
|
||||||
|
int offset = i * ndim;
|
||||||
|
os << " size_t index_" << xname << " = N * pos.x * in_strides["
|
||||||
|
<< offset + ndim - 1 << "]"
|
||||||
|
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
|
||||||
|
} else {
|
||||||
|
os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
|
||||||
|
<< i << " + ndim - 1]"
|
||||||
|
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
|
||||||
|
os << " uint zpos = pos.z;\n";
|
||||||
|
if (dynamic_dims) {
|
||||||
|
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||||
|
} else {
|
||||||
|
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
|
||||||
|
}
|
||||||
|
os << " uint l = zpos % output_shape[d];\n";
|
||||||
|
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||||
|
auto& xname = namer.get_name(nc_inputs[i]);
|
||||||
|
os << " index_" << xname << " += ";
|
||||||
|
if (dynamic_dims) {
|
||||||
|
os << "l * in_strides[" << i << " * ndim + d];\n";
|
||||||
|
} else {
|
||||||
|
os << "l * in_strides[" << i * ndim << " + d];\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << " zpos /= output_shape[d];\n }\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open per-thread loop
|
||||||
|
if (work_per_thread > 1) {
|
||||||
|
os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read non-contiguous inputs into tmps
|
||||||
|
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||||
|
auto& x = nc_inputs[i];
|
||||||
|
auto& xname = namer.get_name(x);
|
||||||
|
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||||
|
<< xname << "[index_" << xname << "];\n";
|
||||||
|
}
|
||||||
|
|
||||||
// Actually write the computation
|
// Actually write the computation
|
||||||
for (auto& x : tape) {
|
for (auto& x : tape) {
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||||
<< " = ";
|
<< " = ";
|
||||||
if (is_static_cast(x.primitive())) {
|
if (is_static_cast(x.primitive())) {
|
||||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
<< namer.get_name(x.inputs()[0]) << ");\n";
|
||||||
} else {
|
} else {
|
||||||
x.primitive().print(os);
|
x.primitive().print(os);
|
||||||
os << "()(";
|
os << "()(";
|
||||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||||
}
|
}
|
||||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
|
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write the outputs from tmps
|
// Write the outputs from tmps
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
||||||
<< ";" << std::endl;
|
<< ";\n";
|
||||||
|
}
|
||||||
|
// Increment indices and close per thread loop
|
||||||
|
if (work_per_thread > 1) {
|
||||||
|
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||||
|
auto& x = nc_inputs[i];
|
||||||
|
auto& xname = namer.get_name(x);
|
||||||
|
if (!dynamic_dims) {
|
||||||
|
os << " index_" << xname << " += "
|
||||||
|
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
|
||||||
|
} else {
|
||||||
|
os << " index_" << xname << " += "
|
||||||
|
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << " index++;\n }\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finish the kernel
|
// Finish the kernel
|
||||||
os << "}" << std::endl;
|
os << "}\n";
|
||||||
|
|
||||||
if (cnt > 31) {
|
if (cnt > 31) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -237,7 +280,9 @@ void Compiled::eval_gpu(
|
|||||||
constant_ids_,
|
constant_ids_,
|
||||||
/* contiguous = */ false,
|
/* contiguous = */ false,
|
||||||
/* ndim = */ i,
|
/* ndim = */ i,
|
||||||
/* dynamic_dims = */ false);
|
/* dynamic_dims = */ false,
|
||||||
|
/* use_big_index = */ false,
|
||||||
|
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
|
||||||
}
|
}
|
||||||
build_kernel(
|
build_kernel(
|
||||||
kernel,
|
kernel,
|
||||||
@ -248,7 +293,9 @@ void Compiled::eval_gpu(
|
|||||||
constant_ids_,
|
constant_ids_,
|
||||||
/* contiguous = */ false,
|
/* contiguous = */ false,
|
||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ true);
|
/* dynamic_dims = */ true,
|
||||||
|
/* use_big_index = */ false,
|
||||||
|
/* work_per_thread = */ WORK_PER_THREAD);
|
||||||
return kernel.str();
|
return kernel.str();
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -384,6 +431,8 @@ void Compiled::eval_gpu(
|
|||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||||
|
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
|
||||||
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size != 1024) {
|
if (thread_group_size != 1024) {
|
||||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||||
|
@ -98,7 +98,7 @@ void copy_gpu_inplace(
|
|||||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||||
kname << shape.size();
|
kname << shape.size();
|
||||||
} else if (shape[ndim - 1] >= 4) {
|
} else {
|
||||||
work_per_thread = 4;
|
work_per_thread = 4;
|
||||||
kname << "n4";
|
kname << "n4";
|
||||||
}
|
}
|
||||||
|
@ -50,8 +50,6 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||||
kernel_source << get_template_definition(
|
kernel_source << get_template_definition(
|
||||||
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||||
kernel_source << get_template_definition(
|
|
||||||
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
|
|
||||||
kernel_source << get_template_definition(
|
kernel_source << get_template_definition(
|
||||||
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
@ -65,7 +63,7 @@ void add_binary_kernels(
|
|||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op,
|
const std::string op,
|
||||||
std::ostringstream& kernel_source) {
|
std::ostringstream& kernel_source) {
|
||||||
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
|
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||||
{"ss", "binary_ss"},
|
{"ss", "binary_ss"},
|
||||||
{"vs", "binary_vs"},
|
{"vs", "binary_vs"},
|
||||||
{"sv", "binary_sv"},
|
{"sv", "binary_sv"},
|
||||||
@ -76,7 +74,6 @@ void add_binary_kernels(
|
|||||||
{"g1", "binary_g_nd1"},
|
{"g1", "binary_g_nd1"},
|
||||||
{"g2", "binary_g_nd2"},
|
{"g2", "binary_g_nd2"},
|
||||||
{"g3", "binary_g_nd3"},
|
{"g3", "binary_g_nd3"},
|
||||||
{"gn", "binary_g"},
|
|
||||||
}};
|
}};
|
||||||
for (auto& [name, func] : kernel_types) {
|
for (auto& [name, func] : kernel_types) {
|
||||||
std::string template_def;
|
std::string template_def;
|
||||||
@ -138,10 +135,9 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
|||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name, [&]() {
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
|
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||||
{"v", "ternary_v"},
|
{"v", "ternary_v"},
|
||||||
{"v2", "ternary_v2"},
|
{"v2", "ternary_v2"},
|
||||||
{"g", "ternary_g"},
|
|
||||||
{"g1", "ternary_g_nd1"},
|
{"g1", "ternary_g_nd1"},
|
||||||
{"g2", "ternary_g_nd2"},
|
{"g2", "ternary_g_nd2"},
|
||||||
{"g3", "ternary_g_nd3"},
|
{"g3", "ternary_g_nd3"},
|
||||||
@ -170,29 +166,27 @@ MTL::ComputePipelineState* get_copy_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
auto in_type = get_type_string(in.dtype());
|
auto in_type = get_type_string(in.dtype());
|
||||||
auto out_type = get_type_string(out.dtype());
|
auto out_type = get_type_string(out.dtype());
|
||||||
kernel_source
|
kernel_source << metal::utils() << metal::copy()
|
||||||
<< metal::utils() << metal::copy()
|
<< get_template_definition(
|
||||||
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
|
"s_" + lib_name, "copy_s", in_type, out_type)
|
||||||
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
|
<< get_template_definition(
|
||||||
<< get_template_definition(
|
"v_" + lib_name, "copy_v", in_type, out_type)
|
||||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
<< get_template_definition(
|
||||||
<< get_template_definition(
|
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
<< get_template_definition(
|
||||||
<< get_template_definition(
|
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
<< get_template_definition(
|
||||||
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
|
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||||
<< get_template_definition(
|
<< get_template_definition(
|
||||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||||
<< get_template_definition(
|
<< get_template_definition(
|
||||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||||
<< get_template_definition(
|
<< get_template_definition(
|
||||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||||
<< get_template_definition(
|
<< get_template_definition(
|
||||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||||
<< get_template_definition(
|
<< get_template_definition(
|
||||||
"gg_" + lib_name, "copy_gg", in_type, out_type)
|
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||||
<< get_template_definition(
|
|
||||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
@ -17,7 +17,6 @@
|
|||||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
|
||||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
|
||||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
|
@ -16,9 +16,7 @@
|
|||||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||||
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
|
|
||||||
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
|
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
|
||||||
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \
|
|
||||||
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
|
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
|
||||||
|
|
||||||
#define instantiate_copy_itype(itname, itype) \
|
#define instantiate_copy_itype(itname, itype) \
|
||||||
|
@ -12,11 +12,10 @@
|
|||||||
#define instantiate_ternary_all(op, tname, type) \
|
#define instantiate_ternary_all(op, tname, type) \
|
||||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||||
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
|
|
||||||
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
|
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
|
||||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
||||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
|
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
|
||||||
|
|
||||||
#define instantiate_ternary_types(op) \
|
#define instantiate_ternary_types(op) \
|
||||||
instantiate_ternary_all(op, bool_, bool) \
|
instantiate_ternary_all(op, bool_, bool) \
|
||||||
|
@ -8,8 +8,7 @@
|
|||||||
#define instantiate_unary_all(op, tname, type) \
|
#define instantiate_unary_all(op, tname, type) \
|
||||||
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
|
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
|
||||||
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
|
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
|
||||||
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) \
|
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4)
|
||||||
instantiate_kernel("g_" #op #tname, unary_g, type, op)
|
|
||||||
|
|
||||||
#define instantiate_unary_float(op) \
|
#define instantiate_unary_float(op) \
|
||||||
instantiate_unary_all(op, float16, half) \
|
instantiate_unary_all(op, float16, half) \
|
||||||
|
@ -38,8 +38,7 @@ void ternary_op_gpu_inplace(
|
|||||||
|
|
||||||
bool use_2d = out.data_size() > UINT_MAX;
|
bool use_2d = out.data_size() > UINT_MAX;
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
int work_per_thread =
|
int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
|
||||||
(topt == TernaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
|
|
||||||
std::string kernel_name;
|
std::string kernel_name;
|
||||||
{
|
{
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
|
@ -35,7 +35,7 @@ void unary_op_gpu_inplace(
|
|||||||
};
|
};
|
||||||
auto [shape, strides] = maybe_collapse();
|
auto [shape, strides] = maybe_collapse();
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
int work_per_thread = (!contig && shape[ndim - 1] > 4) ? 4 : 1;
|
int work_per_thread = !contig ? 4 : 1;
|
||||||
size_t nthreads = contig ? in.data_size() : in.size();
|
size_t nthreads = contig ? in.data_size() : in.size();
|
||||||
bool use_2d = nthreads > UINT32_MAX;
|
bool use_2d = nthreads > UINT32_MAX;
|
||||||
std::string kernel_name;
|
std::string kernel_name;
|
||||||
|
@ -758,6 +758,20 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
y = mx.compile(fn)(x)
|
y = mx.compile(fn)(x)
|
||||||
|
|
||||||
|
def test_compile_dynamic_dims(self):
|
||||||
|
a = mx.random.uniform(shape=(2,) * 10)
|
||||||
|
b = mx.random.uniform(shape=(2,) * 10)
|
||||||
|
a = a.T
|
||||||
|
mx.eval(a, b)
|
||||||
|
|
||||||
|
def fn(a, b):
|
||||||
|
return mx.abs(a + b)
|
||||||
|
|
||||||
|
out = mx.compile(fn)(a, b)
|
||||||
|
expected = fn(a, b)
|
||||||
|
print((out - expected).abs().max())
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user