mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39 * work per thread for compiled kernels * fixe for large arrays * fix
This commit is contained in:
parent
0eef4febfd
commit
881615b072
@ -84,8 +84,7 @@ void binary_op_gpu_inplace(
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread =
|
||||
(bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
|
||||
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
|
||||
std::string kernel_name =
|
||||
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
|
||||
auto& d = metal::device(s.device);
|
||||
|
@ -13,6 +13,8 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int WORK_PER_THREAD = 4;
|
||||
|
||||
inline void build_kernel(
|
||||
std::ostream& os,
|
||||
const std::string& kernel_name,
|
||||
@ -23,7 +25,8 @@ inline void build_kernel(
|
||||
bool contiguous,
|
||||
int ndim,
|
||||
bool dynamic_dims,
|
||||
bool use_big_index = false) {
|
||||
bool use_big_index = false,
|
||||
int work_per_thread = 1) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
@ -38,8 +41,8 @@ inline void build_kernel(
|
||||
int cnt = 0;
|
||||
|
||||
// Start the kernel
|
||||
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
|
||||
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
|
||||
os << "[[host_name(\"" << kernel_name << "\")]]\n"
|
||||
<< "[[kernel]] void " << kernel_name << "(\n";
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
@ -53,11 +56,11 @@ inline void build_kernel(
|
||||
// Scalars and contiguous need no strides
|
||||
if (is_scalar(x) || contiguous) {
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||
} else {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
<< " [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,58 +72,37 @@ inline void build_kernel(
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
||||
<< ")]]," << std::endl
|
||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
<< ")]],\n"
|
||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
|
||||
}
|
||||
|
||||
// The thread index in the whole grid
|
||||
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
|
||||
<< " uint3 grid [[threads_per_grid]]) {" << std::endl;
|
||||
os << " uint3 pos [[thread_position_in_grid]],\n"
|
||||
<< " uint3 grid [[threads_per_grid]]) {\n";
|
||||
|
||||
if (use_big_index) {
|
||||
// This is only used for contiguous kernels which don't have
|
||||
// a third grid dimension
|
||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);";
|
||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n"
|
||||
<< " int xshape = output_shape["
|
||||
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
|
||||
<< " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
} else {
|
||||
os << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
|
||||
}
|
||||
os << std::endl;
|
||||
|
||||
// Extract the indices per axis to individual uints if we have arrays that
|
||||
// are broadcasted or transposed
|
||||
if (add_indices) {
|
||||
if (!dynamic_dims) {
|
||||
if (ndim == 1) {
|
||||
os << " uint index_0 = pos.x;" << std::endl;
|
||||
} else if (ndim == 2) {
|
||||
os << " uint index_0 = pos.y;" << std::endl
|
||||
<< " uint index_1 = pos.x;" << std::endl;
|
||||
} else if (ndim == 3) {
|
||||
os << " uint index_0 = pos.z;" << std::endl
|
||||
<< " uint index_1 = pos.y;" << std::endl
|
||||
<< " uint index_2 = pos.x;" << std::endl;
|
||||
} else {
|
||||
for (int i = 0; i < ndim - 2; i++) {
|
||||
os << " uint index_" << i << " = (index / uint(output_strides[" << i
|
||||
<< "])) % output_shape[" << i << "];" << std::endl;
|
||||
}
|
||||
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
|
||||
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
|
||||
}
|
||||
}
|
||||
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
int nc_in_count = 0;
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
@ -130,56 +112,117 @@ inline void build_kernel(
|
||||
os << " auto tmp_" << xname << " = static_cast<"
|
||||
<< get_type_string(x.dtype()) << ">(";
|
||||
print_constant(os, x);
|
||||
os << ");" << std::endl;
|
||||
os << ");\n";
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
<< xname << "[0];\n";
|
||||
} else if (contiguous) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];" << std::endl;
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = nc_in_count * ndim;
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[";
|
||||
os << "index_0 * " << "in_strides[" << offset << "]";
|
||||
for (int i = 1; i < ndim; i++) {
|
||||
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
|
||||
}
|
||||
os << "];" << std::endl;
|
||||
nc_in_count++;
|
||||
<< xname << "[index];\n";
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
|
||||
<< nc_in_count * ndim << ", ndim)];" << std::endl;
|
||||
nc_in_count++;
|
||||
nc_inputs.push_back(x);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the indices for non-contiguous inputs
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
if (ndim == 1) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
|
||||
<< "in_strides[" << offset << "]);\n";
|
||||
} else if (ndim == 2) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
|
||||
<< "in_strides + " << offset << ");\n";
|
||||
} else if (ndim == 3) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
|
||||
<< "in_strides + " << offset << ");\n";
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = N * pos.x * in_strides["
|
||||
<< offset + ndim - 1 << "]"
|
||||
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
|
||||
} else {
|
||||
os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
|
||||
<< i << " + ndim - 1]"
|
||||
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
|
||||
}
|
||||
}
|
||||
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
|
||||
os << " uint zpos = pos.z;\n";
|
||||
if (dynamic_dims) {
|
||||
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||
} else {
|
||||
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
|
||||
}
|
||||
os << " uint l = zpos % output_shape[d];\n";
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os << " index_" << xname << " += ";
|
||||
if (dynamic_dims) {
|
||||
os << "l * in_strides[" << i << " * ndim + d];\n";
|
||||
} else {
|
||||
os << "l * in_strides[" << i * ndim << " + d];\n";
|
||||
}
|
||||
}
|
||||
os << " zpos /= output_shape[d];\n }\n";
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index_" << xname << "];\n";
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||
<< " = ";
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
<< namer.get_name(x.inputs()[0]) << ");\n";
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
}
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
||||
<< ";" << std::endl;
|
||||
<< ";\n";
|
||||
}
|
||||
// Increment indices and close per thread loop
|
||||
if (work_per_thread > 1) {
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
if (!dynamic_dims) {
|
||||
os << " index_" << xname << " += "
|
||||
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
|
||||
} else {
|
||||
os << " index_" << xname << " += "
|
||||
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
|
||||
}
|
||||
}
|
||||
os << " index++;\n }\n";
|
||||
}
|
||||
|
||||
// Finish the kernel
|
||||
os << "}" << std::endl;
|
||||
os << "}\n";
|
||||
|
||||
if (cnt > 31) {
|
||||
std::ostringstream msg;
|
||||
@ -237,7 +280,9 @@ void Compiled::eval_gpu(
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false);
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
@ -248,7 +293,9 @@ void Compiled::eval_gpu(
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true);
|
||||
/* dynamic_dims = */ true,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread = */ WORK_PER_THREAD);
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
@ -384,6 +431,8 @@ void Compiled::eval_gpu(
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
|
@ -98,7 +98,7 @@ void copy_gpu_inplace(
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else if (shape[ndim - 1] >= 4) {
|
||||
} else {
|
||||
work_per_thread = 4;
|
||||
kname << "n4";
|
||||
}
|
||||
|
@ -50,8 +50,6 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
kernel_source << get_template_definition(
|
||||
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||
kernel_source << get_template_definition(
|
||||
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||
return kernel_source.str();
|
||||
@ -65,7 +63,7 @@ void add_binary_kernels(
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::ostringstream& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
|
||||
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
{"sv", "binary_sv"},
|
||||
@ -76,7 +74,6 @@ void add_binary_kernels(
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g2", "binary_g_nd2"},
|
||||
{"g3", "binary_g_nd3"},
|
||||
{"gn", "binary_g"},
|
||||
}};
|
||||
for (auto& [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
@ -138,10 +135,9 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g", "ternary_g"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
{"g3", "ternary_g_nd3"},
|
||||
@ -170,29 +166,27 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source
|
||||
<< metal::utils() << metal::copy()
|
||||
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
|
||||
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg_" + lib_name, "copy_gg", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
kernel_source << metal::utils() << metal::copy()
|
||||
<< get_template_definition(
|
||||
"s_" + lib_name, "copy_s", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"v_" + lib_name, "copy_v", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
@ -17,7 +17,6 @@
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
|
@ -15,7 +15,6 @@
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
|
@ -16,9 +16,7 @@
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
|
||||
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \
|
||||
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
|
@ -12,11 +12,10 @@
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
|
||||
|
||||
#define instantiate_ternary_types(op) \
|
||||
instantiate_ternary_all(op, bool_, bool) \
|
||||
|
@ -8,8 +8,7 @@
|
||||
#define instantiate_unary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) \
|
||||
instantiate_kernel("g_" #op #tname, unary_g, type, op)
|
||||
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4)
|
||||
|
||||
#define instantiate_unary_float(op) \
|
||||
instantiate_unary_all(op, float16, half) \
|
||||
|
@ -38,8 +38,7 @@ void ternary_op_gpu_inplace(
|
||||
|
||||
bool use_2d = out.data_size() > UINT_MAX;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread =
|
||||
(topt == TernaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
|
||||
int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
|
@ -35,7 +35,7 @@ void unary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = (!contig && shape[ndim - 1] > 4) ? 4 : 1;
|
||||
int work_per_thread = !contig ? 4 : 1;
|
||||
size_t nthreads = contig ? in.data_size() : in.size();
|
||||
bool use_2d = nthreads > UINT32_MAX;
|
||||
std::string kernel_name;
|
||||
|
@ -758,6 +758,20 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
y = mx.compile(fn)(x)
|
||||
|
||||
def test_compile_dynamic_dims(self):
|
||||
a = mx.random.uniform(shape=(2,) * 10)
|
||||
b = mx.random.uniform(shape=(2,) * 10)
|
||||
a = a.T
|
||||
mx.eval(a, b)
|
||||
|
||||
def fn(a, b):
|
||||
return mx.abs(a + b)
|
||||
|
||||
out = mx.compile(fn)(a, b)
|
||||
expected = fn(a, b)
|
||||
print((out - expected).abs().max())
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user