Fix cpu compile (#934)

* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
This commit is contained in:
Awni Hannun 2024-04-01 17:37:12 -07:00 committed by GitHub
parent 639e06e1f3
commit 2427fa171e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 157 additions and 106 deletions

View File

@ -126,4 +126,102 @@ std::string build_lib_name(
return os.str(); return os.str();
} }
// Decide whether a compiled kernel may treat every input as a flat,
// contiguous buffer matching the given output shape.
//
// Rules (scalars are ignored):
// - no non-scalar inputs: contiguous only if the output shape is empty
// - one non-scalar input: it must be contiguous and shape-equal
// - several non-scalar inputs: all must be row-contiguous, or all
//   col-contiguous, and shape-equal
bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const std::vector<int>& shape) {
  bool contig = true;
  bool row_contig = true;
  bool col_contig = true;
  int n_non_scalar = 0;
  for (const auto& in : inputs) {
    if (is_scalar(in)) {
      continue;
    }
    ++n_non_scalar;
    const bool same_shape = (in.shape() == shape);
    contig &= (same_shape && in.flags().contiguous);
    row_contig &= (same_shape && in.flags().row_contiguous);
    col_contig &= (same_shape && in.flags().col_contiguous);
  }
  if (n_non_scalar == 0) {
    return shape.empty();
  }
  if (n_non_scalar == 1) {
    return contig;
  }
  return row_contig || col_contig;
}
// Allocate space for the outputs of a compiled kernel, donating input
// buffers where possible.
//
// inputs        post-broadcast inputs passed to the kernel
// outputs       arrays whose data buffers are set here
// inputs_       original inputs, used to match ids against constant_ids_
// constant_ids_ ids of inputs folded into the kernel as constants; these
//               are never donated
// contiguous    whether the kernel operates on flat, contiguous buffers
// move_buffers  move (instead of copy) the shared buffer on donation
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    bool contiguous,
    bool move_buffers /* = false */) {
  if (contiguous) {
    int o = 0;
    std::vector<size_t> strides;
    // Value-initialize so the fallback allocation below never reads
    // indeterminate values when no representative input is found
    // (e.g. an empty `inputs` vector).
    size_t data_size = 0;
    array::Flags flags{};
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Correct size
      // - Not a scalar
      // - Donatable
      // - Not a constant
      if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        if (move_buffers) {
          outputs[o++].move_shared_buffer(in);
        } else {
          outputs[o++].copy_shared_buffer(in);
        }
      }
      // Get representative input flags to properly set non-donated outputs
      if (strides.empty() && in.size() == outputs[0].size()) {
        strides = in.strides();
        flags = in.flags();
        data_size = in.data_size();
      }
    }
    // Allocate any outputs that did not receive a donated buffer.
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(
          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
          data_size,
          strides,
          flags);
    }
  } else {
    int o = 0;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Row contiguous
      // - Donatable
      // - Correct size
      // - Not a constant
      if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        if (move_buffers) {
          outputs[o].move_shared_buffer(
              in, outputs[o].strides(), in.flags(), in.data_size());
        } else {
          outputs[o].copy_shared_buffer(
              in, outputs[o].strides(), in.flags(), in.data_size());
        }
        o++;
      }
    }
    // Allocate any outputs that did not receive a donated buffer.
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
    }
  }
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
return x.ndim() == 0; return x.ndim() == 0;
} }
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
} // namespace mlx::core } // namespace mlx::core

View File

@ -52,8 +52,25 @@ void* compile(
return nullptr; return nullptr;
} }
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
// characters. Clip file name with a little extra room and append a 16
// character hash.
constexpr int max_file_name_length = 245;
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id = std::hash<std::string>{}(kernel_name);
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
std::ostringstream shared_lib_name; std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_name << ".so"; shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str()); auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false; bool lib_exists = false;
{ {
@ -64,7 +81,7 @@ void* compile(
if (!lib_exists) { if (!lib_exists) {
// Open source file and write source code to it // Open source file and write source code to it
std::ostringstream source_file_name; std::ostringstream source_file_name;
source_file_name << kernel_name << ".cpp"; source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str()); auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path); std::ofstream source_file(source_file_path);
@ -248,28 +265,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using // Figure out which kernel we are using
auto& shape = outputs[0].shape(); auto& shape = outputs[0].shape();
bool contiguous = true; bool contiguous = compiled_check_contiguity(inputs, shape);
{
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
}
}
// Handle all broadcasting and collect function input arguments // Handle all broadcasting and collect function input arguments
std::vector<void*> args; std::vector<void*> args;
@ -342,58 +338,8 @@ void Compiled::eval_cpu(
fn_ptr = compile(kernel_name, kernel.str()); fn_ptr = compile(kernel_name, kernel.str());
} }
// Allocate space for the outputs possibly with input donation compiled_allocate_outputs(
if (contiguous) { inputs, outputs, inputs_, constant_ids_, contiguous, false);
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
for (auto& x : outputs) { for (auto& x : outputs) {
args.push_back(x.data<void>()); args.push_back(x.data<void>());

View File

@ -229,14 +229,7 @@ void Compiled::eval_gpu(
// Figure out which kernel we are using // Figure out which kernel we are using
auto& output_shape = outputs[0].shape(); auto& output_shape = outputs[0].shape();
bool contiguous = true; bool contiguous = compiled_check_contiguity(inputs, output_shape);
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
!is_scalar(x)) {
contiguous = false;
break;
}
}
// Collapse contiguous dims to route to a faster kernel if possible. Also // Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting. // handle all broadcasting.
@ -317,28 +310,8 @@ void Compiled::eval_gpu(
} }
} }
// Allocate space for the outputs possibly with input donation compiled_allocate_outputs(
{ inputs, outputs, inputs_, constant_ids_, contiguous, true);
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
// Put the outputs in // Put the outputs in
for (auto& x : outputs) { for (auto& x : outputs) {

View File

@ -671,6 +671,26 @@ class TestCompile(mlx_tests.MLXTestCase):
out = cmean(x) out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x))) self.assertTrue(mx.allclose(out, mean(x)))
def test_compile_broadcast_only(self):
    # A graph whose only op before the add is a broadcast exercises the
    # zero-non-scalar-input contiguity path on the CPU compile backend.
    def double_broadcast(a):
        a = mx.broadcast_to(a, (1,))
        return a + a

    out = mx.compile(double_broadcast)(mx.array(2.0))
    # Make sure repr can be called
    self.assertIsNotNone(repr(out))
    self.assertTrue(mx.array_equal(out, mx.array([4.0])))
def test_compile_with_long_name(self):
    # A deep graph produces a very long generated kernel name; compilation
    # must still succeed (the on-disk file name gets clipped and hashed).
    def deep_fn(a, b):
        for _step in range(10):
            a = a - 1.0
            b = b - 1.0
        return a + b

    result = mx.compile(deep_fn)(mx.array(10.0), mx.array(20.0))
    self.assertEqual(result.item(), 10.0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()