mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
parent
50d8bed468
commit
92d7cb71f8
@ -103,8 +103,8 @@ void* compile(
|
|||||||
source_file.close();
|
source_file.close();
|
||||||
|
|
||||||
std::ostringstream build_command;
|
std::ostringstream build_command;
|
||||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
|
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
|
||||||
<< source_file_path << " -o " << shared_lib_path;
|
<< source_file_path << "' -o '" << shared_lib_path << "'";
|
||||||
std::string build_command_str = build_command.str();
|
std::string build_command_str = build_command.str();
|
||||||
auto return_code = system(build_command_str.c_str());
|
auto return_code = system(build_command_str.c_str());
|
||||||
if (return_code) {
|
if (return_code) {
|
||||||
|
@ -93,10 +93,10 @@ inline void build_kernel(
|
|||||||
// a third grid dimension
|
// a third grid dimension
|
||||||
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
|
||||||
} else if (work_per_thread > 1) {
|
} else if (work_per_thread > 1) {
|
||||||
os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n"
|
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
|
||||||
<< " int xshape = output_shape["
|
<< " int xshape = output_shape["
|
||||||
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
|
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
|
||||||
<< " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||||
} else {
|
} else {
|
||||||
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
|
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
|
||||||
}
|
}
|
||||||
@ -141,11 +141,11 @@ inline void build_kernel(
|
|||||||
<< "in_strides + " << offset << ");\n";
|
<< "in_strides + " << offset << ");\n";
|
||||||
} else if (!dynamic_dims) {
|
} else if (!dynamic_dims) {
|
||||||
int offset = i * ndim;
|
int offset = i * ndim;
|
||||||
os << " size_t index_" << xname << " = N * pos.x * in_strides["
|
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
|
||||||
<< offset + ndim - 1 << "]"
|
<< offset + ndim - 1 << "]"
|
||||||
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
|
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
|
||||||
} else {
|
} else {
|
||||||
os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
|
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
|
||||||
<< i << " + ndim - 1]"
|
<< i << " + ndim - 1]"
|
||||||
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
|
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
|
||||||
}
|
}
|
||||||
@ -172,7 +172,7 @@ inline void build_kernel(
|
|||||||
|
|
||||||
// Open per-thread loop
|
// Open per-thread loop
|
||||||
if (work_per_thread > 1) {
|
if (work_per_thread > 1) {
|
||||||
os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n";
|
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read non-contiguous inputs into tmps
|
// Read non-contiguous inputs into tmps
|
||||||
@ -434,10 +434,15 @@ void Compiled::eval_gpu(
|
|||||||
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
|
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
|
||||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size != 1024) {
|
int pow2;
|
||||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
if (thread_group_size == 1024) {
|
||||||
|
pow2 = 10;
|
||||||
|
} else if (thread_group_size > 512) {
|
||||||
|
pow2 = 9;
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
|
||||||
}
|
}
|
||||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
|
||||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ std::string type_to_name(const array& a) {
|
|||||||
return tname;
|
return tname;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
||||||
int pows[3] = {0, 0, 0};
|
int pows[3] = {0, 0, 0};
|
||||||
int sum = 0;
|
int sum = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -76,7 +76,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
|||||||
pows[2]++;
|
pows[2]++;
|
||||||
sum++;
|
sum++;
|
||||||
}
|
}
|
||||||
if (sum == presum || sum == 10) {
|
if (sum == presum || sum == pow2) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,8 +30,8 @@ std::string type_to_name(const array& a);
|
|||||||
// Compute the thread block dimensions which fit the given
|
// Compute the thread block dimensions which fit the given
|
||||||
// input dimensions.
|
// input dimensions.
|
||||||
// - The thread block dimensions will be powers of two
|
// - The thread block dimensions will be powers of two
|
||||||
// - The thread block size will be less than 1024
|
// - The thread block size will be less than 2^pow2
|
||||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2);
|
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||||
|
|
||||||
// Computes a 2D grid where each element is < UINT_MAX
|
// Computes a 2D grid where each element is < UINT_MAX
|
||||||
// Assumes:
|
// Assumes:
|
||||||
|
@ -772,6 +772,23 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
print((out - expected).abs().max())
|
print((out - expected).abs().max())
|
||||||
self.assertTrue(mx.allclose(out, expected))
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
def test_compile_many_inputs(self):
|
||||||
|
inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
|
||||||
|
inputs[0] = inputs[0].T
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun(*inputs):
|
||||||
|
x = inputs[0]
|
||||||
|
for y in inputs[1:10]:
|
||||||
|
x = x + y
|
||||||
|
a = inputs[10]
|
||||||
|
for b in inputs[11:]:
|
||||||
|
a = a + b
|
||||||
|
return x + a
|
||||||
|
|
||||||
|
out = fun(*inputs)
|
||||||
|
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user