diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 499cc0ce4..4cccd35ae 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -33,7 +33,6 @@ DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT(Ceil) -DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Copy) DEFAULT_MULTI(CustomVJP) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 149556530..8bf1f43e4 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -1,59 +1,506 @@ // Copyright © 2023-2024 Apple Inc. -#include +#include +#include +#include +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/compiled_preamble.h" +#include "mlx/backend/common/utils.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core { -// Build the real tape -std::pair, std::vector> trace_to_real( - const std::vector& trace_tape, - const std::vector& trace_inputs, - const std::vector& trace_outputs, - const std::vector& inputs) { - std::unordered_map trace_to_real; - for (int i = 0; i < inputs.size(); ++i) { - trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); - } - std::queue tape; - for (auto& a : trace_tape) { - // Find real inputs - std::vector real_inputs; - for (auto& in : a.inputs()) { - real_inputs.push_back(trace_to_real.at(in.id())); - } - tape.push( - array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs))); - trace_to_real.insert({a.id(), tape.back()}); - } - - std::vector outputs; - for (auto& o : trace_outputs) { - outputs.push_back(trace_to_real.at(o.id())); - } - return {tape, outputs}; +std::string get_temp_file(const std::string& name) { + return std::filesystem::temp_directory_path().append(name); } -void Compiled::eval( +std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids) { + std::ostringstream os; + std::ostringstream constant_hasher; + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. + for (auto& a : tape) { + a.primitive().print(os); + } + os << "_"; + + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << ((x.size() == 1) ? "S" : "V"); + } + } + os << "_"; + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + return os.str(); +} + +void print_constant(std::ostream& os, const array& x) { + switch (x.dtype()) { + case float32: + return print_float_constant(os, x); + case float16: + return print_float_constant(os, x); + case bfloat16: + return print_float_constant(os, x); + case complex64: + return print_complex_constant(os, x); + case int8: + return print_int_constant(os, x); + case int16: + return print_int_constant(os, x); + case int32: + return print_int_constant(os, x); + case int64: + return print_int_constant(os, x); + case uint8: + return print_int_constant(os, x); + case uint16: + return print_int_constant(os, x); + case uint32: + return print_int_constant(os, x); + case uint64: + return print_int_constant(os, x); + case bool_: + os << std::boolalpha << x.item(); + return; + default: + throw std::runtime_error("Unsupported constant type"); + } +} + +std::string get_type_string(Dtype d) { + switch (d) { + case float32: + return "float"; + case float16: + return "float16_t"; + case bfloat16: + return "bfloat16_t"; + case complex64: + return "complex64_t"; + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + default: { + std::ostringstream msg; + msg << "Unsupported compilation type " << d; + throw std::runtime_error(msg.str()); + } + } +} + +inline bool is_scalar(const array& x) { + return x.size() == 1; +}; + +// Return a pointer to a compiled function +void* compile( + const std::string& kernel_name, + const std::string& source_code = "") { + struct DLib { + DLib(const std::string& libname) { + lib = dlopen(libname.c_str(), RTLD_NOW); + if (!lib) { + std::ostringstream msg; + msg << "Could not load C++ shared library " << dlerror(); + throw std::runtime_error(msg.str()); + } + } + + ~DLib() { + dlclose(lib); + } + void* lib; + }; + // Statics to cache compiled libraries and functions + static std::list libs; + static std::unordered_map kernels; + if (auto it = kernels.find(kernel_name); it != kernels.end()) { + return it->second; + } + if (source_code.empty()) { + return nullptr; + } + + std::ostringstream shared_lib_name; + shared_lib_name << "lib" << kernel_name << ".so"; + auto shared_lib_path = get_temp_file(shared_lib_name.str()); + bool lib_exists = false; + { + std::ifstream f(shared_lib_path.c_str()); + lib_exists = f.good(); + } + + if (!lib_exists) { + // Open source file and write source code to it + std::ostringstream source_file_name; + source_file_name << kernel_name << ".cpp"; + auto source_file_path = get_temp_file(source_file_name.str()); + + std::ofstream source_file(source_file_path); + source_file << source_code; + source_file.close(); + + std::ostringstream build_command; + build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " + << source_file_path << " -o " << shared_lib_path; + std::string build_command_str = build_command.str(); + system(build_command_str.c_str()); + } + + // load library + libs.emplace_back(shared_lib_path); + + // Load function + void* fun = dlsym(libs.back().lib, kernel_name.c_str()); + if (!fun) { + std::ostringstream msg; + msg << "[Compile::eval_cpu] Failed to load compiled function " + << kernel_name << std::endl + << dlerror(); + throw std::runtime_error(msg.str()); + } + kernels.insert({kernel_name, fun}); + return fun; +} + +inline void build_kernel( + std::ostream& os, + const std::string& kernel_name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids, + bool contiguous, + int ndim) { + // All outputs should have the exact same shape and will be row contiguous + auto output_shape = outputs[0].shape(); + auto output_strides = outputs[0].strides(); + + // Constants are scalars that are captured by value and cannot change + auto is_constant = [&constant_ids](const array& x) { + return constant_ids.find(x.id()) != constant_ids.end(); + }; + + NodeNamer namer; + + // Start the kernel + os << "void " << kernel_name << "(void** args) {" << std::endl; + + // Add the input arguments + int cnt = 0; + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + // Skip constants from the input list + if (is_constant(x)) { + continue; + } + + auto tstr = get_type_string(x.dtype()); + os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ + << "];" << std::endl; + // Scalars and contiguous need no strides + if (!is_scalar(x) && !contiguous) { + os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++ + << "];" << std::endl; + } + } + + // Add the output arguments + for (auto& x : outputs) { + auto tstr = get_type_string(x.dtype()); + os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr + << "*)args[" << cnt++ << "];" << std::endl; + } + // Add output strides and shape to extract the indices. + if (!contiguous) { + os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl; + } else { + os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; + } + + if (contiguous) { + os << " for (size_t i = 0; i < size; ++i) {" << std::endl; + } else { + for (int d = 0; d < ndim; ++d) { + os << " for (int i" << d << " = 0; i" << d << " < shape[" << d + << "]; ++i" << d << ") {" << std::endl; + } + } + + // Read the inputs in tmps + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + if (is_constant(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; + print_constant(os, x); + os << ";" << std::endl; + } else if (is_scalar(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[0];" << std::endl; + } else if (contiguous) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[i];" << std::endl; + } else { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *" + << xname << ";" << std::endl; + } + } + + // Actually write the computation + for (auto& x : tape) { + os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) + << " = "; + if (is_static_cast(x.primitive())) { + os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" + << namer.get_name(x.inputs()[0]) << ");" << std::endl; + } else { + x.primitive().print(os); + os << "()("; + for (int i = 0; i < x.inputs().size() - 1; i++) { + os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; + } + os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; + } + } + + // Write the outputs from tmps + for (auto& x : outputs) { + if (contiguous) { + os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x) + << ";" << std::endl; + } else { + os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x) + << ";" << std::endl; + } + } + + // Close loops + if (contiguous) { + os << " }" << std::endl; + } else { + for (int d = ndim - 1; d >= 0; --d) { + // Update pointers + for (auto& x : inputs) { + if (is_constant(x) || is_scalar(x)) { + continue; + } + auto& xname = namer.get_name(x); + os << " " << xname << " += " << xname << "_strides[" << d << "];" + << std::endl; + if (d < ndim - 1) { + os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]" + << " * shape[" << d + 1 << "];" << std::endl; + } + } + os << " }" << std::endl; + } + } + + // Finish the kernel + os << "}" << std::endl; +} + +void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { - // Make the a real tape from the tracers - auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs); - - // Run the tape - while (!tape.empty()) { - auto a = std::move(tape.front()); - tape.pop(); - auto outputs = a.outputs(); - a.primitive().eval_cpu(a.inputs(), outputs); - a.detach(); + if (kernel_lib_.empty()) { + kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); } - // Copy results into outputs - for (int o = 0; o < real_outputs.size(); ++o) { - outputs[o].copy_shared_buffer(real_outputs[o]); + // Figure out which kernel we are using + auto& shape = outputs[0].shape(); + bool contiguous = true; + { + bool all_contig = true; + bool all_row_contig = true; + bool all_col_contig = true; + int non_scalar_inputs = 0; + for (auto& x : inputs) { + if (x.size() == 1) { + continue; + } + non_scalar_inputs++; + bool shape_eq = x.shape() == shape; + all_contig &= (x.flags().contiguous && shape_eq); + all_row_contig &= (x.flags().row_contiguous && shape_eq); + all_col_contig &= (x.flags().col_contiguous && shape_eq); + } + if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) { + contiguous = false; + } else if (non_scalar_inputs == 1 && !all_contig) { + contiguous = false; + } } + + // Handle all broadcasting and collect function input arguments + std::vector args; + std::vector> strides; + for (int i = 0; i < inputs.size(); i++) { + // Skip constants. + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + args.push_back((void*)x.data()); + + if (contiguous || x.size() <= 1) { + continue; + } + + // Broadcast the input to the output shape. + std::vector xstrides; + int j = 0; + for (; j < shape.size() - x.ndim(); j++) { + if (shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (int i = 0; i < x.ndim(); i++, j++) { + if (x.shape(i) == 1) { + if (shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + strides.push_back(std::move(xstrides)); + args.push_back(strides.back().data()); + } + + // Get the kernel name from the lib + int ndim = shape.size(); + bool dynamic = ndim >= 8; + auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + if (!contiguous) { + kernel_name += std::to_string(shape.size()); + } + + // Get the function + auto fn_ptr = compile(kernel_name); + + // If it doesn't exist, compile it + if (fn_ptr == nullptr) { + std::ostringstream kernel; + kernel << preamble << std::endl; + kernel << "extern \"C\" {" << std::endl; + build_kernel( + kernel, + kernel_name, + inputs_, + outputs_, + tape_, + constant_ids_, + contiguous, + ndim); + // Close extern "C" + kernel << "}" << std::endl; + + // Compile and get function pointer + fn_ptr = compile(kernel_name, kernel.str()); + } + + // Allocate space for the outputs possibly with input donation + if (contiguous) { + int o = 0; + std::vector strides; + size_t data_size; + array::Flags flags; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().contiguous && in.size() > 1 && in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].copy_shared_buffer(in); + } + // Get representative input flags to properly set non-donated outputs + if (strides.empty() && in.size() == outputs[0].size()) { + strides = in.strides(); + flags = in.flags(); + data_size = in.data_size(); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data( + allocator::malloc_or_wait(data_size * outputs[o].itemsize()), + data_size, + strides, + flags); + } + } else { + int o = 0; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Row contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].copy_shared_buffer(in); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + } + } + + for (auto& x : outputs) { + args.push_back(x.data()); + } + if (!contiguous) { + args.push_back((void*)outputs[0].shape().data()); + } else { + args.push_back((void*)outputs[0].data_size()); + } + auto fun = (void (*)(void**))fn_ptr; + fun(args.data()); } } // namespace mlx::core diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h new file mode 100644 index 000000000..adbd5399c --- /dev/null +++ b/mlx/backend/common/compiled.h @@ -0,0 +1,52 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +inline bool is_static_cast(const Primitive& p) { + return ( + typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || + typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); +} + +std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids); + +std::string get_type_string(Dtype d); + +template +void print_float_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + os << std::setprecision(std::numeric_limits::digits10 + 1) + << x.item() << std::setprecision(old_precision); +} + +template +void print_int_constant(std::ostream& os, const array& x) { + os << x.item(); +} + +template +void print_complex_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + T constant = x.item(); + + os << get_type_string(x.dtype()) << "(" + << std::setprecision(std::numeric_limits::digits10 + 1) + << constant.real() << ", " << constant.imag() << ")" + << std::setprecision(old_precision); +} + +void print_constant(std::ostream& os, const array& x); + +} // namespace mlx::core diff --git a/mlx/backend/common/compiled_preamble.h b/mlx/backend/common/compiled_preamble.h new file mode 100644 index 000000000..8ccaa8bd7 --- /dev/null +++ b/mlx/backend/common/compiled_preamble.h @@ -0,0 +1,1121 @@ +// Copyright © 2023-2024 Apple Inc. + +const std::string preamble = R"( +#include +#include +#include + +#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +#include +typedef __fp16 float16_t; + +#else + +#define ADD_HALF_BINOPS +#include +#include +#include +#include + +#define __MLX_HALF_NAN__ 0x7D00 + + +namespace { +union float_bits_fp16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_Float16 { + uint16_t bits_; + + // Default constructor + _MLX_Float16() = default; + + // Default copy constructor + _MLX_Float16(_MLX_Float16 const&) = default; + + // Appease std::vector for being special + _MLX_Float16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_Float16& operator=(const float& x) { + return (*this = _MLX_Float16(x)); + } + + // From float32 + _MLX_Float16(const float& x) : bits_(0) { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 in; + + // Take fp32 bits + in.f = x; + + // Find and take sign bit + uint32_t x_sign_32 = in.u & uint32_t(0x80000000); + uint16_t x_sign_16 = (x_sign_32 >> 16); + + if (std::isnan(x)) { + bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__); + } else { + // Union + float_bits_fp16 inf_scale, zero_scale, magic_bits; + + // Find exponent bits and take the max supported by half + uint32_t x_expo_32 = in.u & uint32_t(0x7f800000); + uint32_t max_expo_32 = uint32_t(0x38800000); + x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32; + x_expo_32 += uint32_t(15) << 23; + + // Handle scaling to inf as needed + inf_scale.u = uint32_t(0x77800000); + zero_scale.u = uint32_t(0x08800000); + + // Combine with magic and let addition do rounding + magic_bits.u = x_expo_32; + magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f; + + // Take the lower 5 bits of the exponent + uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00)); + + // Collect the lower 12 bits which have the mantissa + uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff); + + // Combine sign, exp and mantissa + bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16)); + } + } + + // To float32 + operator float() const { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 out; + + uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000); + uint32_t base = (bits_ << 16); + uint32_t two_base = base + base; + + uint32_t denorm_max = 1u << 27; + if (two_base < denorm_max) { + out.u = uint32_t(126) << 23; // magic mask + out.u |= (two_base >> 17); // Bits from fp16 + out.f -= 0.5f; // magic bias + } else { + out.u = uint32_t(0xE0) << 23; // exponent offset + out.u += (two_base >> 4); // Bits from fp16 + float out_unscaled = out.f; // Store value + out.u = uint32_t(0x7800000); // exponent scale + out.f *= out_unscaled; + } + + // Add sign + out.u |= x_sign_32; + + return out.f; + } +}; + +#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define half_binop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, float, float, float); \ + half_binop_helper(__op__, __operator__, double, double, double); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float); + +half_binop(+, operator+); +half_binop(-, operator-); +half_binop(*, operator*); +half_binop(/, operator/); + +#undef half_binop + +// Comparison ops +#define half_compop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, bool, float, float); \ + half_binop_helper(__op__, __operator__, bool, double, double); \ + half_binop_helper(__op__, __operator__, bool, int32_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + half_binop_helper(__op__, __operator__, bool, int64_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint64_t, float); + +half_compop(>, operator>); +half_compop(<, operator<); +half_compop(>=, operator>=); +half_compop(<=, operator<=); +half_compop(==, operator==); +half_compop(!=, operator!=); + +#undef half_compop + +// Negative +inline _MLX_Float16 operator-(_MLX_Float16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define half_inplace_op(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +half_inplace_op(+, operator+=); +half_inplace_op(-, operator-=); +half_inplace_op(*, operator*=); +half_inplace_op(/, operator/=); + +#undef half_inplace_op + +// Bitwise ops + +#define half_bitop(__op__, __operator__) \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +half_bitop(|, operator|); +half_bitop(&, operator&); +half_bitop(^, operator^); + +#undef half_bitop + +#define half_inplace_bitop(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +half_inplace_bitop(|, operator|=); +half_inplace_bitop(&, operator&=); +half_inplace_bitop(^, operator^=); + +#undef half_inplace_bitop + +typedef struct _MLX_Float16 float16_t; + +#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#ifdef __ARM_FEATURE_BF16 + +#include +typedef __bf16 bfloat16_t; + +#else + +#define ADD_HALF_BINOPS +#include +#include +#include +#include + +#define __MLX_BFLOAT_NAN__ 0x7FC0 + + +namespace { +union float_bits_bf16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_BFloat16 { + uint16_t bits_; + + // Default constructor + _MLX_BFloat16() = default; + + // Default copy constructor + _MLX_BFloat16(_MLX_BFloat16 const&) = default; + + // Appease std::vector for being special + _MLX_BFloat16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_BFloat16& operator=(const float& x) { + return (*this = _MLX_BFloat16(x)); + } + + // From float32 + _MLX_BFloat16(const float& x) { + if (std::isnan(x)) { + bits_ = __MLX_BFLOAT_NAN__; + } else { + // Union + float_bits_bf16 in; + + // Take bits + in.f = x; + + // Round to nearest even + in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); + + // Take upper 16 bits + bits_ = in.u >> 16; + } + } + + // To float32 + operator float() const { + // Union + float_bits_bf16 out; + + // Upper 16 bits are the data and lower 16 bits are 0s + out.u = ((uint32_t)bits_) << 16; + + return out.f; + } +}; + +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, double, double, double); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +#undef bfloat_binop + +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, double, double); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop + +// Negative +inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define bfloat_inplace_op(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_op(+, operator+=); +bfloat_inplace_op(-, operator-=); +bfloat_inplace_op(*, operator*=); +bfloat_inplace_op(/, operator/=); + +#undef bfloat_inplace_op + +// Bitwise ops + +#define bfloat_bitop(__op__, __operator__) \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +bfloat_bitop(|, operator|); +bfloat_bitop(&, operator&); +bfloat_bitop(^, operator^); + +#undef bfloat_bitop + +#define bfloat_inplace_bitop(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_bitop(|, operator|=); +bfloat_inplace_bitop(&, operator&=); +bfloat_inplace_bitop(^, operator^=); + +#undef bfloat_inplace_bitop + +typedef struct _MLX_BFloat16 bfloat16_t; + +#endif // __ARM_FEATURE_BF16 + +#ifdef ADD_HALF_BINOPS + +// clang-format off +#define fp16_bf16_binop_helper(__op__, __operator__) \ + inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp16_bf16_binop_helper(+, operator+) +fp16_bf16_binop_helper(-, operator-) +fp16_bf16_binop_helper(*, operator*) +fp16_bf16_binop_helper(/, operator/) +// clang-format on + +#endif + + +struct complex64_t; + +template +static constexpr bool can_convert_to_complex64 = + !std::is_same_v && std::is_convertible_v; + +struct complex64_t : public std::complex { + complex64_t(float v, float u) : std::complex(v, u){}; + complex64_t(std::complex v) : std::complex(v){}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex64_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +inline bool operator>=(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || + (a.real() == b.real() && a.imag() >= b.imag()); +} + +inline bool operator>(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); +} + +inline complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && (real < 0 != b.real() < 0)) + real += b.real(); + if (imag != 0 && (imag < 0 != b.imag() < 0)) + imag += b.imag(); + return {real, imag}; +} + +inline bool operator<=(const complex64_t& a, const complex64_t& b) { + return operator>=(b, a); +} + +inline bool operator<(const complex64_t& a, const complex64_t& b) { + return operator>(b, a); +} + +inline complex64_t operator-(const complex64_t& v) { + return -static_cast>(v); +} + +// clang-format off +#define complex_binop_helper(_op_, _operator_, itype) \ + inline complex64_t _operator_(itype x, const complex64_t& y) { \ + return static_cast(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, itype y) { \ + return x _op_ static_cast(y); \ + } + +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ + return static_cast>(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ + complex_binop_helper(_op_, _operator_, float) +// clang-format on + +complex_binop(+, operator+) + +typedef union { + int i; + float f; +} IntOrFloat; + +inline float fast_exp(float x) { + x *= 1.442695; // multiply with log_2(e) + float ipart, fpart; + IntOrFloat epart; + x = std::max(-80.f, std::min(x, 80.f)); + ipart = std::floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = x * fpart + 1.339887440266574e-3f; + x = x * fpart + 9.618437357674640e-3f; + x = x * fpart + 5.550332471162809e-2f; + x = x * fpart + 2.402264791363012e-1f; + x = x * fpart + 6.931472028550421e-1f; + x = x * fpart + 1.000000000000000f; + + // generate 2**ipart in the floating point representation using integer + // bitshifting + epart.i = (int(ipart) + 127) << 23; + + return epart.f * x; +} + +float fast_erf(float a) { + float r, s, t, u; + t = std::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = std::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = std::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = std::fma(r, s, u); + r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = std::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - std::exp(r); + r = std::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = std::fma(r, a, a); + } + return r; +} + +float fast_erfinv(float a) { + auto t = std::fma(a, 0.0f - a, 1.0f); + t = std::log(t); + float p; + if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} + +struct Abs { + template + T operator()(T x) { + return std::abs(x); + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return std::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return std::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return std::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return std::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return std::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return std::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return std::ceil(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return std::cos(x); + }; +}; + +struct Cosh { + template + T operator()(T x) { + return std::cosh(x); + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(fast_erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(fast_erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return fast_exp(x); + }; +}; + +struct Floor { + template + T operator()(T x) { + return std::floor(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return std::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return std::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return std::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return std::rint(x); + } + + std::complex operator()(std::complex x) { + return {std::rint(x.real()), std::rint(x.imag())}; + } +}; + +struct Sigmoid { + template + T operator()(T x) { + auto one = static_cast(1.0); + return one / (one + fast_exp(-x)); + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + } + uint8_t operator()(uint8_t x) { + return x != 0; + } + uint16_t operator()(uint16_t x) { + return x != 0; + } + uint32_t operator()(uint32_t x) { + return x != 0; + } + uint64_t operator()(uint64_t x) { + return x != 0; + } +}; + +struct Sin { + template + T operator()(T x) { + return std::sin(x); + }; +}; + +struct Sinh { + template + T operator()(T x) { + return std::sinh(x); + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return std::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return static_cast(1.0) / std::sqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return std::tan(x); + }; +}; + +struct Tanh { + template + T operator()(T x) { + return std::tanh(x); + }; +}; + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + std::enable_if_t & !std::is_signed_v, T> operator()( + T numerator, + T denominator) { + return numerator % denominator; + } + + template + std::enable_if_t & std::is_signed_v, T> operator()( + T numerator, + T denominator) { + auto r = numerator % denominator; + if (r != 0 && (r < 0 != denominator < 0)) + r += denominator; + return r; + } + + template + std::enable_if_t, T> operator()( + T numerator, + T denominator) { + auto r = std::fmod(numerator, denominator); + if (r != 0 && (r < 0 != denominator < 0)) { + r += denominator; + } + return r; + } + + std::complex operator()( + std::complex a, std::complex b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && ((real < 0) != (b.real() < 0))) + real += b.real(); + if (imag != 0 && ((imag < 0) != (b.imag() < 0))) + imag += b.imag(); + return {real, imag}; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (std::isnan(x) && std::isnan(y)); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + constexpr float inf = std::numeric_limits::infinity(); + auto maxval = (x > y) ? x : y; + auto minval = (x > y) ? y : x; + return (minval == -inf || maxval == inf) + ? maxval + : static_cast( + maxval + std::log1p(fast_exp(minval - maxval))); + }; +}; + +struct Maximum { + template + std::enable_if_t, T> operator()(T x, T y) { + return (x > y) ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return (x > y) ? x : y; + } +}; + +struct Minimum { + template + std::enable_if_t, T> operator()(T x, T y) { + return x < y ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return x < y ? x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } +}; + +struct Power { + template + std::enable_if_t, T> operator()(T base, T exp) { + return std::pow(base, exp); + } + + template + std::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; +)"; diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 6befc8eb9..c65028d95 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -43,7 +43,6 @@ DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT_MULTI(DivMod) DEFAULT(Ceil) -DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Convolution) DEFAULT(Copy) diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 1f27a2493..681d635ba 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/compiled_preamble.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" @@ -11,125 +12,6 @@ namespace mlx::core { -inline bool is_static_cast(const Primitive& p) { - return ( - typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || - typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); -} - -inline auto get_type_string(Dtype d) { - switch (d) { - case float32: - return "float"; - case float16: - return "half"; - case bfloat16: - return "bfloat16_t"; - case bool_: - return "bool"; - case int8: - return "int8_t"; - case int16: - return "int16_t"; - case int32: - return "int32_t"; - case int64: - return "int64_t"; - case uint8: - return "uint8_t"; - case uint16: - return "uint16_t"; - case uint32: - return "uint32_t"; - case uint64: - return "uint64_t"; - default: { - std::ostringstream msg; - msg << "Unsupported compilation type " << d; - throw std::runtime_error(msg.str()); - } - } -} - -template -void print_float_constant(std::ostream& os, const array& x) { - auto old_precision = os.precision(); - os << std::setprecision(std::numeric_limits::digits10 + 1) - << x.item() << std::setprecision(old_precision); -} - -template -void print_int_constant(std::ostream& os, const array& x) { - os << x.item(); -} - -void print_constant(std::ostream& os, const array& x) { - switch (x.dtype()) { - case float32: - return print_float_constant(os, x); - case float16: - return print_float_constant(os, x); - case bfloat16: - return print_float_constant(os, x); - case int8: - return print_int_constant(os, x); - case int16: - return print_int_constant(os, x); - case int32: - return print_int_constant(os, x); - case int64: - return print_int_constant(os, x); - case uint8: - return print_int_constant(os, x); - case uint16: - return print_int_constant(os, x); - case uint32: - return print_int_constant(os, x); - case uint64: - return print_int_constant(os, x); - case bool_: - os << std::boolalpha << x.item(); - return; - default: - throw std::runtime_error("Unsupported constant type"); - } -} - -inline std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids) { - std::ostringstream os; - std::ostringstream constant_hasher; - - // The primitives describing the tape. For unary and binary primitives this - // must be enough to describe the full computation. - for (auto& a : tape) { - a.primitive().print(os); - } - os << "_"; - - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - os << "C"; - print_constant(constant_hasher, x); - } else { - os << ((x.size() == 1) ? "S" : "V"); - } - } - os << "_"; - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - continue; - } - os << kindof(x.dtype()) << x.itemsize(); - } - os << "_" << std::hash{}(constant_hasher.str()); - - return os.str(); -} - inline void build_kernel( std::ostream& os, const std::string& kernel_name, @@ -286,7 +168,7 @@ inline void build_kernel( if (cnt > 31) { std::ostringstream msg; - msg << "[compile] Too many inputs/outputs fused in the Metal Compile " + msg << "[compile] Too many inputs/outputs fused in the Metal Compiled " << "primitive which exhausted the available argument buffers for " << "the kernel. Please file an issue with the function that results " << "in this error. The name of the kernel is '" << kernel_name << "'"; @@ -348,11 +230,6 @@ void Compiled::eval_gpu( lib = d.get_library(kernel_lib_, kernel_source_); } - // Allocate space for the outputs - for (auto& out : outputs) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - } - // Figure out which kernel we are using auto& output_shape = outputs[0].shape(); bool contiguous = true; @@ -443,6 +320,27 @@ void Compiled::eval_gpu( } } + // Allocate space for the outputs possibly with input donation + { + int o = 0; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Row contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].move_shared_buffer(in); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + } + } + // Put the outputs in for (auto& x : outputs) { set_array_buffer(compute_encoder, x, cnt++); diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h index 82a9e9c5c..d5bf33696 100644 --- a/mlx/backend/metal/kernels/compiled_preamble.h +++ b/mlx/backend/metal/kernels/compiled_preamble.h @@ -2,3 +2,5 @@ #include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/unary.h" + +typedef half float16_t; diff --git a/mlx/compile.cpp b/mlx/compile.cpp index e69c442f2..a648d191f 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -319,6 +319,9 @@ void compile_simplify( case 1: v = *a.data(); break; + case 2: + v = *a.data(); + break; case 4: v = *a.data(); break; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index acb5d0046..2b854960b 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -3,8 +3,8 @@ namespace mlx::core::fast { // Custom primitive accepts a fallback function which it uses for -// transformations. Transformations are virtual so that derived classes may to -// override the default behavior +// transformations. Transformations are virtual so that derived classes may +// override the default behavior. class Custom : public Primitive { public: explicit Custom( diff --git a/mlx/primitives.h b/mlx/primitives.h index b06a35780..9d0a9181c 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -496,8 +496,6 @@ class Compiled : public Primitive { std::string kernel_lib_; std::string kernel_source_; - - void eval(const std::vector& inputs, std::vector& out); }; class Concatenate : public UnaryPrimitive { diff --git a/mlx/types/complex.h b/mlx/types/complex.h index 19ab1b542..46f4310f9 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -60,25 +60,30 @@ inline complex64_t operator-(const complex64_t& v) { // clang-format off #define complex_binop_helper(_op_, _operator_, itype) \ inline complex64_t _operator_(itype x, const complex64_t& y) { \ - return x _op_ static_cast>(y); \ + return static_cast(x) _op_ y; \ } \ inline complex64_t _operator_(const complex64_t& x, itype y) { \ - return static_cast>(x) _op_ y; \ + return x _op_ static_cast(y); \ } -#define complex_binop(_op_, _operator_) \ - inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ - return static_cast>(x) \ - _op_ static_cast>(y); \ - } \ - complex_binop_helper(_op_, _operator_, bool) \ - complex_binop_helper(_op_, _operator_, uint32_t) \ - complex_binop_helper(_op_, _operator_, uint64_t) \ - complex_binop_helper(_op_, _operator_, int32_t) \ - complex_binop_helper(_op_, _operator_, int64_t) \ - complex_binop_helper(_op_, _operator_, float16_t) \ - complex_binop_helper(_op_, _operator_, bfloat16_t) \ - complex_binop_helper(_op_, _operator_, const std::complex&) \ +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ + return static_cast>(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ complex_binop_helper(_op_, _operator_, float) // clang-format on diff --git a/mlx/utils.h b/mlx/utils.h index 88f47e3e1..ebcca3a1e 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -77,7 +77,7 @@ std::ostream& operator<<(std::ostream& os, array a); std::ostream& operator<<(std::ostream& os, const std::vector& v); std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { - return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j"; + return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; } inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { return os << static_cast(v); diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index be460e3b6..8ad67a1ed 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -44,8 +44,8 @@ TEST_CASE("test compile with grad") { auto y = array(1.0f); auto grads_expected = grad_fun({x, y}); auto grads_compile = compile(grad_fun)({x, y}); - CHECK_EQ(grads_compile[0].item(), grads_expected[0].item()); - CHECK_EQ(grads_compile[1].item(), grads_expected[1].item()); + CHECK(allclose(grads_compile[0], grads_expected[0]).item()); + CHECK(allclose(grads_compile[1], grads_expected[1]).item()); } TEST_CASE("test compile inputs with primitive") { @@ -272,7 +272,7 @@ TEST_CASE("test compile unary fused") { CHECK_EQ(out.inputs()[0].id(), x.id()); auto expected_out = unary_fused_1({array(2.0)})[0]; - CHECK_EQ(out.item(), expected_out.item()); + CHECK(allclose(out, expected_out).item()); } {