mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
CPU compile (#691)
* build and load shared object for cpu compile * nits * cpu compile tests pass * cpu compile tests pass * fix preamble for g++ * donation * fix gpu buffer donation * reuse prebuilt libraries * faster contiguity conditoins * fix test * rid compiler warning * fast erf * Fix float16 for compile and add more types to cpu compile * Remove a forgotten comment * use cached libs * nits --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
c3965fc5ee
commit
dc937b8ed3
@ -33,7 +33,6 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT_MULTI(Compiled)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
|
@ -1,59 +1,506 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <queue>
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include <list>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/compiled_preamble.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Build the real tape
|
||||
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
|
||||
const std::vector<array>& trace_tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs) {
|
||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||
}
|
||||
std::queue<array> tape;
|
||||
for (auto& a : trace_tape) {
|
||||
// Find real inputs
|
||||
std::vector<array> real_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
tape.push(
|
||||
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
|
||||
trace_to_real.insert({a.id(), tape.back()});
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
for (auto& o : trace_outputs) {
|
||||
outputs.push_back(trace_to_real.at(o.id()));
|
||||
}
|
||||
return {tape, outputs};
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name);
|
||||
}
|
||||
|
||||
void Compiled::eval(
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
a.primitive().print(os);
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << ((x.size() == 1) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void print_constant(std::ostream& os, const array& x) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
return print_float_constant<float>(os, x);
|
||||
case float16:
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
return print_int_constant<int8_t>(os, x);
|
||||
case int16:
|
||||
return print_int_constant<int16_t>(os, x);
|
||||
case int32:
|
||||
return print_int_constant<int32_t>(os, x);
|
||||
case int64:
|
||||
return print_int_constant<int64_t>(os, x);
|
||||
case uint8:
|
||||
return print_int_constant<uint8_t>(os, x);
|
||||
case uint16:
|
||||
return print_int_constant<uint16_t>(os, x);
|
||||
case uint32:
|
||||
return print_int_constant<uint32_t>(os, x);
|
||||
case uint64:
|
||||
return print_int_constant<uint64_t>(os, x);
|
||||
case bool_:
|
||||
os << std::boolalpha << x.item<bool>();
|
||||
return;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported constant type");
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_type_string(Dtype d) {
|
||||
switch (d) {
|
||||
case float32:
|
||||
return "float";
|
||||
case float16:
|
||||
return "float16_t";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case complex64:
|
||||
return "complex64_t";
|
||||
case bool_:
|
||||
return "bool";
|
||||
case int8:
|
||||
return "int8_t";
|
||||
case int16:
|
||||
return "int16_t";
|
||||
case int32:
|
||||
return "int32_t";
|
||||
case int64:
|
||||
return "int64_t";
|
||||
case uint8:
|
||||
return "uint8_t";
|
||||
case uint16:
|
||||
return "uint16_t";
|
||||
case uint32:
|
||||
return "uint32_t";
|
||||
case uint64:
|
||||
return "uint64_t";
|
||||
default: {
|
||||
std::ostringstream msg;
|
||||
msg << "Unsupported compilation type " << d;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_scalar(const array& x) {
|
||||
return x.size() == 1;
|
||||
};
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::string& source_code = "") {
|
||||
struct DLib {
|
||||
DLib(const std::string& libname) {
|
||||
lib = dlopen(libname.c_str(), RTLD_NOW);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Could not load C++ shared library " << dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
~DLib() {
|
||||
dlclose(lib);
|
||||
}
|
||||
void* lib;
|
||||
};
|
||||
// Statics to cache compiled libraries and functions
|
||||
static std::list<DLib> libs;
|
||||
static std::unordered_map<std::string, void*> kernels;
|
||||
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
if (source_code.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::ostringstream shared_lib_name;
|
||||
shared_lib_name << "lib" << kernel_name << ".so";
|
||||
auto shared_lib_path = get_temp_file(shared_lib_name.str());
|
||||
bool lib_exists = false;
|
||||
{
|
||||
std::ifstream f(shared_lib_path.c_str());
|
||||
lib_exists = f.good();
|
||||
}
|
||||
|
||||
if (!lib_exists) {
|
||||
// Open source file and write source code to it
|
||||
std::ostringstream source_file_name;
|
||||
source_file_name << kernel_name << ".cpp";
|
||||
auto source_file_path = get_temp_file(source_file_name.str());
|
||||
|
||||
std::ofstream source_file(source_file_path);
|
||||
source_file << source_code;
|
||||
source_file.close();
|
||||
|
||||
std::ostringstream build_command;
|
||||
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
|
||||
<< source_file_path << " -o " << shared_lib_path;
|
||||
std::string build_command_str = build_command.str();
|
||||
system(build_command_str.c_str());
|
||||
}
|
||||
|
||||
// load library
|
||||
libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
<< kernel_name << std::endl
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
inline void build_kernel(
|
||||
std::ostream& os,
|
||||
const std::string& kernel_name,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
bool contiguous,
|
||||
int ndim) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto tstr = get_type_string(x.dtype());
|
||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
auto tstr = get_type_string(x.dtype());
|
||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
||||
} else {
|
||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
|
||||
if (contiguous) {
|
||||
os << " for (size_t i = 0; i < size; ++i) {" << std::endl;
|
||||
} else {
|
||||
for (int d = 0; d < ndim; ++d) {
|
||||
os << " for (int i" << d << " = 0; i" << d << " < shape[" << d
|
||||
<< "]; ++i" << d << ") {" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
} else if (contiguous) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[i];" << std::endl;
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *"
|
||||
<< xname << ";" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||
<< " = ";
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
}
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
if (contiguous) {
|
||||
os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x)
|
||||
<< ";" << std::endl;
|
||||
} else {
|
||||
os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x)
|
||||
<< ";" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Close loops
|
||||
if (contiguous) {
|
||||
os << " }" << std::endl;
|
||||
} else {
|
||||
for (int d = ndim - 1; d >= 0; --d) {
|
||||
// Update pointers
|
||||
for (auto& x : inputs) {
|
||||
if (is_constant(x) || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
auto& xname = namer.get_name(x);
|
||||
os << " " << xname << " += " << xname << "_strides[" << d << "];"
|
||||
<< std::endl;
|
||||
if (d < ndim - 1) {
|
||||
os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]"
|
||||
<< " * shape[" << d + 1 << "];" << std::endl;
|
||||
}
|
||||
}
|
||||
os << " }" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Finish the kernel
|
||||
os << "}" << std::endl;
|
||||
}
|
||||
|
||||
void Compiled::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Make the a real tape from the tracers
|
||||
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
|
||||
|
||||
// Run the tape
|
||||
while (!tape.empty()) {
|
||||
auto a = std::move(tape.front());
|
||||
tape.pop();
|
||||
auto outputs = a.outputs();
|
||||
a.primitive().eval_cpu(a.inputs(), outputs);
|
||||
a.detach();
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Copy results into outputs
|
||||
for (int o = 0; o < real_outputs.size(); ++o) {
|
||||
outputs[o].copy_shared_buffer(real_outputs[o]);
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
{
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (auto& x : inputs) {
|
||||
if (x.size() == 1) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
bool shape_eq = x.shape() == shape;
|
||||
all_contig &= (x.flags().contiguous && shape_eq);
|
||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||
}
|
||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||
contiguous = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || x.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the input to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < shape.size() - x.ndim(); j++) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides.push_back(std::move(xstrides));
|
||||
args.push_back(strides.back().data());
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
int ndim = shape.size();
|
||||
bool dynamic = ndim >= 8;
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name);
|
||||
|
||||
// If it doesn't exist, compile it
|
||||
if (fn_ptr == nullptr) {
|
||||
std::ostringstream kernel;
|
||||
kernel << preamble << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_name,
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
contiguous,
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
kernel << "}" << std::endl;
|
||||
|
||||
// Compile and get function pointer
|
||||
fn_ptr = compile(kernel_name, kernel.str());
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().contiguous && in.size() > 1 && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
strides = in.strides();
|
||||
flags = in.flags();
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
}
|
||||
if (!contiguous) {
|
||||
args.push_back((void*)outputs[0].shape().data());
|
||||
} else {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
fun(args.data());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
52
mlx/backend/common/compiled.h
Normal file
52
mlx/backend/common/compiled.h
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline bool is_static_cast(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids);
|
||||
|
||||
std::string get_type_string(Dtype d);
|
||||
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_int_constant(std::ostream& os, const array& x) {
|
||||
os << x.item<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_complex_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
T constant = x.item<T>();
|
||||
|
||||
os << get_type_string(x.dtype()) << "("
|
||||
<< std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< constant.real() << ", " << constant.imag() << ")"
|
||||
<< std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
void print_constant(std::ostream& os, const array& x);
|
||||
|
||||
} // namespace mlx::core
|
1121
mlx/backend/common/compiled_preamble.h
Normal file
1121
mlx/backend/common/compiled_preamble.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -43,7 +43,6 @@ DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT_MULTI(Compiled)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/compiled_preamble.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@ -11,125 +12,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline bool is_static_cast(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
inline auto get_type_string(Dtype d) {
|
||||
switch (d) {
|
||||
case float32:
|
||||
return "float";
|
||||
case float16:
|
||||
return "half";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case bool_:
|
||||
return "bool";
|
||||
case int8:
|
||||
return "int8_t";
|
||||
case int16:
|
||||
return "int16_t";
|
||||
case int32:
|
||||
return "int32_t";
|
||||
case int64:
|
||||
return "int64_t";
|
||||
case uint8:
|
||||
return "uint8_t";
|
||||
case uint16:
|
||||
return "uint16_t";
|
||||
case uint32:
|
||||
return "uint32_t";
|
||||
case uint64:
|
||||
return "uint64_t";
|
||||
default: {
|
||||
std::ostringstream msg;
|
||||
msg << "Unsupported compilation type " << d;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_int_constant(std::ostream& os, const array& x) {
|
||||
os << x.item<T>();
|
||||
}
|
||||
|
||||
void print_constant(std::ostream& os, const array& x) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
return print_float_constant<float>(os, x);
|
||||
case float16:
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case int8:
|
||||
return print_int_constant<int8_t>(os, x);
|
||||
case int16:
|
||||
return print_int_constant<int16_t>(os, x);
|
||||
case int32:
|
||||
return print_int_constant<int32_t>(os, x);
|
||||
case int64:
|
||||
return print_int_constant<int64_t>(os, x);
|
||||
case uint8:
|
||||
return print_int_constant<uint8_t>(os, x);
|
||||
case uint16:
|
||||
return print_int_constant<uint16_t>(os, x);
|
||||
case uint32:
|
||||
return print_int_constant<uint32_t>(os, x);
|
||||
case uint64:
|
||||
return print_int_constant<uint64_t>(os, x);
|
||||
case bool_:
|
||||
os << std::boolalpha << x.item<bool>();
|
||||
return;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported constant type");
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
a.primitive().print(os);
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << ((x.size() == 1) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
inline void build_kernel(
|
||||
std::ostream& os,
|
||||
const std::string& kernel_name,
|
||||
@ -286,7 +168,7 @@ inline void build_kernel(
|
||||
|
||||
if (cnt > 31) {
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Too many inputs/outputs fused in the Metal Compile "
|
||||
msg << "[compile] Too many inputs/outputs fused in the Metal Compiled "
|
||||
<< "primitive which exhausted the available argument buffers for "
|
||||
<< "the kernel. Please file an issue with the function that results "
|
||||
<< "in this error. The name of the kernel is '" << kernel_name << "'";
|
||||
@ -348,11 +230,6 @@ void Compiled::eval_gpu(
|
||||
lib = d.get_library(kernel_lib_, kernel_source_);
|
||||
}
|
||||
|
||||
// Allocate space for the outputs
|
||||
for (auto& out : outputs) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
@ -443,6 +320,27 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
{
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
|
@ -2,3 +2,5 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
typedef half float16_t;
|
||||
|
@ -319,6 +319,9 @@ void compile_simplify(
|
||||
case 1:
|
||||
v = *a.data<uint8_t>();
|
||||
break;
|
||||
case 2:
|
||||
v = *a.data<uint16_t>();
|
||||
break;
|
||||
case 4:
|
||||
v = *a.data<uint32_t>();
|
||||
break;
|
||||
|
@ -3,8 +3,8 @@
|
||||
namespace mlx::core::fast {
|
||||
|
||||
// Custom primitive accepts a fallback function which it uses for
|
||||
// transformations. Transformations are virtual so that derived classes may to
|
||||
// override the default behavior
|
||||
// transformations. Transformations are virtual so that derived classes may
|
||||
// override the default behavior.
|
||||
class Custom : public Primitive {
|
||||
public:
|
||||
explicit Custom(
|
||||
|
@ -496,8 +496,6 @@ class Compiled : public Primitive {
|
||||
|
||||
std::string kernel_lib_;
|
||||
std::string kernel_source_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& out);
|
||||
};
|
||||
|
||||
class Concatenate : public UnaryPrimitive {
|
||||
|
@ -60,25 +60,30 @@ inline complex64_t operator-(const complex64_t& v) {
|
||||
// clang-format off
|
||||
#define complex_binop_helper(_op_, _operator_, itype) \
|
||||
inline complex64_t _operator_(itype x, const complex64_t& y) { \
|
||||
return x _op_ static_cast<std::complex<float>>(y); \
|
||||
return static_cast<complex64_t>(x) _op_ y; \
|
||||
} \
|
||||
inline complex64_t _operator_(const complex64_t& x, itype y) { \
|
||||
return static_cast<std::complex<float>>(x) _op_ y; \
|
||||
return x _op_ static_cast<complex64_t>(y); \
|
||||
}
|
||||
|
||||
#define complex_binop(_op_, _operator_) \
|
||||
inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
|
||||
return static_cast<std::complex<float>>(x) \
|
||||
_op_ static_cast<std::complex<float>>(y); \
|
||||
} \
|
||||
complex_binop_helper(_op_, _operator_, bool) \
|
||||
complex_binop_helper(_op_, _operator_, uint32_t) \
|
||||
complex_binop_helper(_op_, _operator_, uint64_t) \
|
||||
complex_binop_helper(_op_, _operator_, int32_t) \
|
||||
complex_binop_helper(_op_, _operator_, int64_t) \
|
||||
complex_binop_helper(_op_, _operator_, float16_t) \
|
||||
complex_binop_helper(_op_, _operator_, bfloat16_t) \
|
||||
complex_binop_helper(_op_, _operator_, const std::complex<float>&) \
|
||||
#define complex_binop(_op_, _operator_) \
|
||||
inline complex64_t _operator_(const std::complex<float>& x, const complex64_t& y) { \
|
||||
return x _op_ static_cast<std::complex<float>>(y); \
|
||||
} \
|
||||
inline complex64_t _operator_(const complex64_t& x, const std::complex<float>& y) { \
|
||||
return static_cast<std::complex<float>>(x) _op_ y; \
|
||||
} \
|
||||
inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
|
||||
return static_cast<std::complex<float>>(x) \
|
||||
_op_ static_cast<std::complex<float>>(y); \
|
||||
} \
|
||||
complex_binop_helper(_op_, _operator_, bool) \
|
||||
complex_binop_helper(_op_, _operator_, uint32_t) \
|
||||
complex_binop_helper(_op_, _operator_, uint64_t) \
|
||||
complex_binop_helper(_op_, _operator_, int32_t) \
|
||||
complex_binop_helper(_op_, _operator_, int64_t) \
|
||||
complex_binop_helper(_op_, _operator_, float16_t) \
|
||||
complex_binop_helper(_op_, _operator_, bfloat16_t) \
|
||||
complex_binop_helper(_op_, _operator_, float)
|
||||
// clang-format on
|
||||
|
||||
|
@ -77,7 +77,7 @@ std::ostream& operator<<(std::ostream& os, array a);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
|
||||
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
||||
return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j";
|
||||
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
|
||||
}
|
||||
inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
||||
return os << static_cast<float>(v);
|
||||
|
@ -44,8 +44,8 @@ TEST_CASE("test compile with grad") {
|
||||
auto y = array(1.0f);
|
||||
auto grads_expected = grad_fun({x, y});
|
||||
auto grads_compile = compile(grad_fun)({x, y});
|
||||
CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
|
||||
CHECK_EQ(grads_compile[1].item<float>(), grads_expected[1].item<float>());
|
||||
CHECK(allclose(grads_compile[0], grads_expected[0]).item<bool>());
|
||||
CHECK(allclose(grads_compile[1], grads_expected[1]).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test compile inputs with primitive") {
|
||||
@ -272,7 +272,7 @@ TEST_CASE("test compile unary fused") {
|
||||
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||
|
||||
auto expected_out = unary_fused_1({array(2.0)})[0];
|
||||
CHECK_EQ(out.item<float>(), expected_out.item<float>());
|
||||
CHECK(allclose(out, expected_out).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user