Fix compile (#1501)

* fix compile

* fix space
This commit is contained in:
Awni Hannun 2024-10-18 11:06:40 -07:00 committed by GitHub
parent 50d8bed468
commit 92d7cb71f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 36 additions and 14 deletions

View File

@ -103,8 +103,8 @@ void* compile(
source_file.close(); source_file.close();
std::ostringstream build_command; std::ostringstream build_command;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared " build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << " -o " << shared_lib_path; << source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str(); std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str()); auto return_code = system(build_command_str.c_str());
if (return_code) { if (return_code) {

View File

@ -93,10 +93,10 @@ inline void build_kernel(
// a third grid dimension // a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n"; os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) { } else if (work_per_thread > 1) {
os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n" os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
<< " int xshape = output_shape[" << " int xshape = output_shape["
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n" << (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n"; << " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
} else { } else {
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n"; os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
} }
@ -141,11 +141,11 @@ inline void build_kernel(
<< "in_strides + " << offset << ");\n"; << "in_strides + " << offset << ");\n";
} else if (!dynamic_dims) { } else if (!dynamic_dims) {
int offset = i * ndim; int offset = i * ndim;
os << " size_t index_" << xname << " = N * pos.x * in_strides[" os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]" << offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n"; << " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
} else { } else {
os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * " os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]" << i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n"; << " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
} }
@ -172,7 +172,7 @@ inline void build_kernel(
// Open per-thread loop // Open per-thread loop
if (work_per_thread > 1) { if (work_per_thread > 1) {
os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n"; os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
} }
// Read non-contiguous inputs into tmps // Read non-contiguous inputs into tmps
@ -434,10 +434,15 @@ void Compiled::eval_gpu(
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1; int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) { int pow2;
throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); if (thread_group_size == 1024) {
pow2 = 10;
} else if (thread_group_size > 512) {
pow2 = 9;
} else {
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
} }
auto group_dims = get_block_dims(dim0, dim1, rest); auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }

View File

@ -52,7 +52,7 @@ std::string type_to_name(const array& a) {
return tname; return tname;
} }
MTL::Size get_block_dims(int dim0, int dim1, int dim2) { MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0}; int pows[3] = {0, 0, 0};
int sum = 0; int sum = 0;
while (true) { while (true) {
@ -76,7 +76,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
pows[2]++; pows[2]++;
sum++; sum++;
} }
if (sum == presum || sum == 10) { if (sum == presum || sum == pow2) {
break; break;
} }
} }

View File

@ -30,8 +30,8 @@ std::string type_to_name(const array& a);
// Compute the thread block dimensions which fit the given // Compute the thread block dimensions which fit the given
// input dimensions. // input dimensions.
// - The thread block dimensions will be powers of two // - The thread block dimensions will be powers of two
// - The thread block size will be less than 1024 // - The thread block size will be less than 2^pow2
MTL::Size get_block_dims(int dim0, int dim1, int dim2); MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
// Computes a 2D grid where each element is < UINT_MAX // Computes a 2D grid where each element is < UINT_MAX
// Assumes: // Assumes:

View File

@ -772,6 +772,23 @@ class TestCompile(mlx_tests.MLXTestCase):
print((out - expected).abs().max()) print((out - expected).abs().max())
self.assertTrue(mx.allclose(out, expected)) self.assertTrue(mx.allclose(out, expected))
def test_compile_many_inputs(self):
    # Regression test: a compiled function taking 20 array arguments.
    # The first input is transposed — presumably to also exercise a
    # non-contiguous input in the compiled graph (TODO confirm intent).
    args = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
    args[0] = args[0].T

    @mx.compile
    def fun(*inputs):
        # Accumulate the first ten and the last ten inputs separately,
        # then combine — same left-to-right add order as two loops.
        first = sum(inputs[1:10], inputs[0])
        second = sum(inputs[11:], inputs[10])
        return first + second

    out = fun(*args)
    # 20 ones summed elementwise -> 20 everywhere; (2, 2) broadcasts
    # against the (2, 2, 2, 2) result inside allclose.
    self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()