mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix cpu compile (#934)
* Fix one CPU bug; add a test for another
* Format hooks
* Simplify contiguity check for CPU compile
* Fix
* Add back donation
* Comment
This commit is contained in:
parent
639e06e1f3
commit
2427fa171e
@ -126,4 +126,102 @@ std::string build_lib_name(
|
|||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool compiled_check_contiguity(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& shape) {
|
||||||
|
bool contiguous = true;
|
||||||
|
bool all_contig = true;
|
||||||
|
bool all_row_contig = true;
|
||||||
|
bool all_col_contig = true;
|
||||||
|
int non_scalar_inputs = 0;
|
||||||
|
for (const auto& x : inputs) {
|
||||||
|
if (is_scalar(x)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
non_scalar_inputs++;
|
||||||
|
bool shape_eq = x.shape() == shape;
|
||||||
|
all_contig &= (x.flags().contiguous && shape_eq);
|
||||||
|
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||||
|
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||||
|
}
|
||||||
|
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||||
|
contiguous = false;
|
||||||
|
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||||
|
contiguous = false;
|
||||||
|
} else if (non_scalar_inputs == 0 && !shape.empty()) {
|
||||||
|
contiguous = false;
|
||||||
|
}
|
||||||
|
return contiguous;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compiled_allocate_outputs(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& inputs_,
|
||||||
|
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||||
|
bool contiguous,
|
||||||
|
bool move_buffers /* = false */) {
|
||||||
|
if (contiguous) {
|
||||||
|
int o = 0;
|
||||||
|
std::vector<size_t> strides;
|
||||||
|
size_t data_size;
|
||||||
|
array::Flags flags;
|
||||||
|
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||||
|
auto& in = inputs[i];
|
||||||
|
// Conditions for donation
|
||||||
|
// - Correct size
|
||||||
|
// - Not a scalar
|
||||||
|
// - Donatable
|
||||||
|
// - Not a constant
|
||||||
|
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||||
|
in.is_donatable() &&
|
||||||
|
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||||
|
if (move_buffers) {
|
||||||
|
outputs[o++].move_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
outputs[o++].copy_shared_buffer(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Get representative input flags to properly set non-donated outputs
|
||||||
|
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||||
|
strides = in.strides();
|
||||||
|
flags = in.flags();
|
||||||
|
data_size = in.data_size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; o < outputs.size(); ++o) {
|
||||||
|
outputs[o].set_data(
|
||||||
|
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||||
|
data_size,
|
||||||
|
strides,
|
||||||
|
flags);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int o = 0;
|
||||||
|
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||||
|
auto& in = inputs[i];
|
||||||
|
// Conditions for donation
|
||||||
|
// - Row contiguous
|
||||||
|
// - Donatable
|
||||||
|
// - Correct size
|
||||||
|
// - Not a constant
|
||||||
|
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||||
|
in.is_donatable() &&
|
||||||
|
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||||
|
if (move_buffers) {
|
||||||
|
outputs[o].move_shared_buffer(
|
||||||
|
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||||
|
} else {
|
||||||
|
outputs[o].copy_shared_buffer(
|
||||||
|
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||||
|
}
|
||||||
|
o++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; o < outputs.size(); ++o) {
|
||||||
|
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
|
|||||||
return x.ndim() == 0;
|
return x.ndim() == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we can use a contiguous operation given inputs and the output shape
|
||||||
|
bool compiled_check_contiguity(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& shape);
|
||||||
|
|
||||||
|
// Allocate space for the outputs possibly with input donation
|
||||||
|
void compiled_allocate_outputs(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& inputs_,
|
||||||
|
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||||
|
bool contiguous,
|
||||||
|
bool move_buffers = false);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -52,8 +52,25 @@ void* compile(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string kernel_file_name;
|
||||||
|
|
||||||
|
// Deal with long kernel names. Maximum length for files on macOS is 255
|
||||||
|
// characters. Clip file name with a little extra room and append a 16
|
||||||
|
// character hash.
|
||||||
|
constexpr int max_file_name_length = 245;
|
||||||
|
if (kernel_name.size() > max_file_name_length) {
|
||||||
|
std::ostringstream file_name;
|
||||||
|
file_name
|
||||||
|
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
|
||||||
|
auto file_id = std::hash<std::string>{}(kernel_name);
|
||||||
|
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
|
||||||
|
kernel_file_name = file_name.str();
|
||||||
|
} else {
|
||||||
|
kernel_file_name = kernel_name;
|
||||||
|
}
|
||||||
|
|
||||||
std::ostringstream shared_lib_name;
|
std::ostringstream shared_lib_name;
|
||||||
shared_lib_name << "lib" << kernel_name << ".so";
|
shared_lib_name << "lib" << kernel_file_name << ".so";
|
||||||
auto shared_lib_path = get_temp_file(shared_lib_name.str());
|
auto shared_lib_path = get_temp_file(shared_lib_name.str());
|
||||||
bool lib_exists = false;
|
bool lib_exists = false;
|
||||||
{
|
{
|
||||||
@ -64,7 +81,7 @@ void* compile(
|
|||||||
if (!lib_exists) {
|
if (!lib_exists) {
|
||||||
// Open source file and write source code to it
|
// Open source file and write source code to it
|
||||||
std::ostringstream source_file_name;
|
std::ostringstream source_file_name;
|
||||||
source_file_name << kernel_name << ".cpp";
|
source_file_name << kernel_file_name << ".cpp";
|
||||||
auto source_file_path = get_temp_file(source_file_name.str());
|
auto source_file_path = get_temp_file(source_file_name.str());
|
||||||
|
|
||||||
std::ofstream source_file(source_file_path);
|
std::ofstream source_file(source_file_path);
|
||||||
@ -248,28 +265,7 @@ void Compiled::eval_cpu(
|
|||||||
|
|
||||||
// Figure out which kernel we are using
|
// Figure out which kernel we are using
|
||||||
auto& shape = outputs[0].shape();
|
auto& shape = outputs[0].shape();
|
||||||
bool contiguous = true;
|
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||||
{
|
|
||||||
bool all_contig = true;
|
|
||||||
bool all_row_contig = true;
|
|
||||||
bool all_col_contig = true;
|
|
||||||
int non_scalar_inputs = 0;
|
|
||||||
for (auto& x : inputs) {
|
|
||||||
if (is_scalar(x)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
non_scalar_inputs++;
|
|
||||||
bool shape_eq = x.shape() == shape;
|
|
||||||
all_contig &= (x.flags().contiguous && shape_eq);
|
|
||||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
|
||||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
|
||||||
}
|
|
||||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
|
||||||
contiguous = false;
|
|
||||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
|
||||||
contiguous = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle all broadcasting and collect function input arguments
|
// Handle all broadcasting and collect function input arguments
|
||||||
std::vector<void*> args;
|
std::vector<void*> args;
|
||||||
@ -342,58 +338,8 @@ void Compiled::eval_cpu(
|
|||||||
fn_ptr = compile(kernel_name, kernel.str());
|
fn_ptr = compile(kernel_name, kernel.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate space for the outputs possibly with input donation
|
compiled_allocate_outputs(
|
||||||
if (contiguous) {
|
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||||
int o = 0;
|
|
||||||
std::vector<size_t> strides;
|
|
||||||
size_t data_size;
|
|
||||||
array::Flags flags;
|
|
||||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
|
||||||
auto& in = inputs[i];
|
|
||||||
// Conditions for donation
|
|
||||||
// - Contiguous
|
|
||||||
// - Donatable
|
|
||||||
// - Correct size
|
|
||||||
// - Not a constant
|
|
||||||
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
|
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
|
||||||
outputs[o++].copy_shared_buffer(in);
|
|
||||||
}
|
|
||||||
// Get representative input flags to properly set non-donated outputs
|
|
||||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
|
||||||
strides = in.strides();
|
|
||||||
flags = in.flags();
|
|
||||||
data_size = in.data_size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (; o < outputs.size(); ++o) {
|
|
||||||
outputs[o].set_data(
|
|
||||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
|
||||||
data_size,
|
|
||||||
strides,
|
|
||||||
flags);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
int o = 0;
|
|
||||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
|
||||||
auto& in = inputs[i];
|
|
||||||
// Conditions for donation
|
|
||||||
// - Row contiguous
|
|
||||||
// - Donatable
|
|
||||||
// - Correct size
|
|
||||||
// - Not a constant
|
|
||||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
|
||||||
in.is_donatable() &&
|
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
|
||||||
outputs[o].copy_shared_buffer(
|
|
||||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
|
||||||
o++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (; o < outputs.size(); ++o) {
|
|
||||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
args.push_back(x.data<void>());
|
args.push_back(x.data<void>());
|
||||||
|
@ -229,14 +229,7 @@ void Compiled::eval_gpu(
|
|||||||
|
|
||||||
// Figure out which kernel we are using
|
// Figure out which kernel we are using
|
||||||
auto& output_shape = outputs[0].shape();
|
auto& output_shape = outputs[0].shape();
|
||||||
bool contiguous = true;
|
bool contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||||
for (auto& x : inputs) {
|
|
||||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
|
||||||
!is_scalar(x)) {
|
|
||||||
contiguous = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||||
// handle all broadcasting.
|
// handle all broadcasting.
|
||||||
@ -317,28 +310,8 @@ void Compiled::eval_gpu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate space for the outputs possibly with input donation
|
compiled_allocate_outputs(
|
||||||
{
|
inputs, outputs, inputs_, constant_ids_, contiguous, true);
|
||||||
int o = 0;
|
|
||||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
|
||||||
auto& in = inputs[i];
|
|
||||||
// Conditions for donation
|
|
||||||
// - Row contiguous
|
|
||||||
// - Donatable
|
|
||||||
// - Correct size
|
|
||||||
// - Not a constant
|
|
||||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
|
||||||
in.is_donatable() &&
|
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
|
||||||
outputs[o].move_shared_buffer(
|
|
||||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
|
||||||
o++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (; o < outputs.size(); ++o) {
|
|
||||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put the outputs in
|
// Put the outputs in
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
|
@ -671,6 +671,26 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = cmean(x)
|
out = cmean(x)
|
||||||
self.assertTrue(mx.allclose(out, mean(x)))
|
self.assertTrue(mx.allclose(out, mean(x)))
|
||||||
|
|
||||||
|
def test_compile_broadcast_only(self):
    # A compiled graph whose inputs only feed a broadcast should still
    # evaluate correctly and produce a printable result.
    def fn(a):
        b = mx.broadcast_to(a, (1,))
        return b + b

    result = mx.compile(fn)(mx.array(2.0))
    # repr must not blow up on the compiled output
    self.assertIsNotNone(repr(result))
    self.assertTrue(mx.array_equal(result, mx.array([4.0])))
||||||
|
|
||||||
|
def test_compile_with_long_name(self):
    # Build a deep graph so the generated kernel name becomes very long;
    # compilation must still succeed (long names get hashed/clipped).
    def fn(a, b):
        for _ in range(10):
            a = a - 1.0
            b = b - 1.0
        return a + b

    result = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
    self.assertEqual(result.item(), 10.0)
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user