mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
parent
50d8bed468
commit
92d7cb71f8
@ -103,8 +103,8 @@ void* compile(
|
||||
source_file.close();
|
||||
|
||||
std::ostringstream build_command;
|
||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
|
||||
<< source_file_path << " -o " << shared_lib_path;
|
||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
|
||||
<< source_file_path << "' -o '" << shared_lib_path << "'";
|
||||
std::string build_command_str = build_command.str();
|
||||
auto return_code = system(build_command_str.c_str());
|
||||
if (return_code) {
|
||||
|
@ -93,10 +93,10 @@ inline void build_kernel(
|
||||
// a third grid dimension
|
||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n"
|
||||
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
|
||||
<< " int xshape = output_shape["
|
||||
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
|
||||
<< " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
} else {
|
||||
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||
}
|
||||
@ -141,11 +141,11 @@ inline void build_kernel(
|
||||
<< "in_strides + " << offset << ");\n";
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = i * ndim;
|
||||
os << " size_t index_" << xname << " = N * pos.x * in_strides["
|
||||
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
|
||||
<< offset + ndim - 1 << "]"
|
||||
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
|
||||
} else {
|
||||
os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
|
||||
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
|
||||
<< i << " + ndim - 1]"
|
||||
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
|
||||
}
|
||||
@ -172,7 +172,7 @@ inline void build_kernel(
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n";
|
||||
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
@ -434,10 +434,15 @@ void Compiled::eval_gpu(
|
||||
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
int pow2;
|
||||
if (thread_group_size == 1024) {
|
||||
pow2 = 10;
|
||||
} else if (thread_group_size > 512) {
|
||||
pow2 = 9;
|
||||
} else {
|
||||
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ std::string type_to_name(const array& a) {
|
||||
return tname;
|
||||
}
|
||||
|
||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
||||
int pows[3] = {0, 0, 0};
|
||||
int sum = 0;
|
||||
while (true) {
|
||||
@ -76,7 +76,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
pows[2]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == presum || sum == 10) {
|
||||
if (sum == presum || sum == pow2) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -30,8 +30,8 @@ std::string type_to_name(const array& a);
|
||||
// Compute the thread block dimensions which fit the given
|
||||
// input dimensions.
|
||||
// - The thread block dimensions will be powers of two
|
||||
// - The thread block size will be less than 1024
|
||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2);
|
||||
// - The thread block size will be less than 2^pow2
|
||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
|
||||
// Computes a 2D grid where each element is < UINT_MAX
|
||||
// Assumes:
|
||||
|
@ -772,6 +772,23 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
print((out - expected).abs().max())
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_compile_many_inputs(self):
|
||||
inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
|
||||
inputs[0] = inputs[0].T
|
||||
|
||||
@mx.compile
|
||||
def fun(*inputs):
|
||||
x = inputs[0]
|
||||
for y in inputs[1:10]:
|
||||
x = x + y
|
||||
a = inputs[10]
|
||||
for b in inputs[11:]:
|
||||
a = a + b
|
||||
return x + a
|
||||
|
||||
out = fun(*inputs)
|
||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user