diff --git a/mlx/array.h b/mlx/array.h index 2b849a7ae..5eefcf727 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -121,6 +121,9 @@ class array { template T item(); + template + T item() const; + struct ArrayIterator { using iterator_category = std::random_access_iterator_tag; using difference_type = size_t; @@ -454,6 +457,18 @@ T array::item() { return *data(); } +template +T array::item() const { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + if (!is_evaled()) { + throw std::invalid_argument( + "item() const can only be called on evaled arrays"); + } + return *data(); +} + template void array::init(It src) { set_data(allocator::malloc(size() * size_of(dtype()))); diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index fd1a47f01..93a25434f 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -1,3 +1,23 @@ +add_custom_command( + OUTPUT compiled_preamble.cpp + COMMAND /bin/bash + ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_C_COMPILER} + ${CMAKE_SOURCE_DIR} + DEPENDS make_compiled_preamble.sh + kernels/compiled_preamble.h + kernels/unary.h + kernels/binary.h +) + +add_custom_target( + compiled_preamble + DEPENDS compiled_preamble.cpp +) + +add_dependencies(mlx compiled_preamble) + target_sources( mlx PRIVATE @@ -16,6 +36,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) if (NOT MLX_METAL_PATH) diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 4ce9b0c85..1f27a2493 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -1,44 +1,484 @@ // Copyright © 2023-2024 Apple Inc. +#include + +#include "mlx/backend/metal/compiled_preamble.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core { +inline bool is_static_cast(const Primitive& p) { + return ( + typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || + typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); +} + +inline auto get_type_string(Dtype d) { + switch (d) { + case float32: + return "float"; + case float16: + return "half"; + case bfloat16: + return "bfloat16_t"; + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + default: { + std::ostringstream msg; + msg << "Unsupported compilation type " << d; + throw std::runtime_error(msg.str()); + } + } +} + +template +void print_float_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + os << std::setprecision(std::numeric_limits::digits10 + 1) + << x.item() << std::setprecision(old_precision); +} + +template +void print_int_constant(std::ostream& os, const array& x) { + os << x.item(); +} + +void print_constant(std::ostream& os, const array& x) { + switch (x.dtype()) { + case float32: + return print_float_constant(os, x); + case float16: + return print_float_constant(os, x); + case bfloat16: + return print_float_constant(os, x); + case int8: + return print_int_constant(os, x); + case int16: + return print_int_constant(os, x); + case int32: + return print_int_constant(os, x); + case int64: + return print_int_constant(os, x); + case uint8: + return print_int_constant(os, x); + case uint16: + return print_int_constant(os, x); + case uint32: + return print_int_constant(os, x); + case uint64: + return print_int_constant(os, x); + case bool_: + os << std::boolalpha << x.item(); + return; + default: + throw std::runtime_error("Unsupported constant type"); + } +} + +inline std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids) { + std::ostringstream os; + std::ostringstream constant_hasher; + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. + for (auto& a : tape) { + a.primitive().print(os); + } + os << "_"; + + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << ((x.size() == 1) ? "S" : "V"); + } + } + os << "_"; + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + return os.str(); +} + +inline void build_kernel( + std::ostream& os, + const std::string& kernel_name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids, + bool contiguous, + int ndim, + bool dynamic_dims) { + // All outputs should have the exact same shape and will be row contiguous + auto output_shape = outputs[0].shape(); + auto output_strides = outputs[0].strides(); + + // Constants are scalars that are captured by value and cannot change + auto is_constant = [&constant_ids](const array& x) { + return constant_ids.find(x.id()) != constant_ids.end(); + }; + + // For scalar we shouldn't do the indexing things, just read at 0 + auto is_scalar = [](const array& x) { return x.size() == 1; }; + + NodeNamer namer; + bool add_indices = false; + int cnt = 0; + + // Start the kernel + os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl + << "[[kernel]] void " << kernel_name << "(" << std::endl; + + // Add the input arguments + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + // Skip constants from the input list + if (is_constant(x)) { + continue; + } + + // Scalars and contiguous need no strides + if (is_scalar(x) || contiguous) { + os << " device const " << get_type_string(x.dtype()) << "* " << xname + << " [[buffer(" << cnt++ << ")]]," << std::endl; + } else { + add_indices = true; + os << " device const " << get_type_string(x.dtype()) << "* " << xname + << " [[buffer(" << cnt++ << ")]]," << std::endl + << " constant const size_t* " << xname << "_strides [[buffer(" + << cnt++ << ")]]," << std::endl; + } + } + + // Add the output arguments + for (auto& x : outputs) { + os << " device " << get_type_string(x.dtype()) << "* " + << namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl; + } + // Add output strides and shape to extract the indices. + if (!contiguous) { + os << " constant const size_t* output_strides [[buffer(" << cnt++ + << ")]]," << std::endl + << " constant const int* output_shape [[buffer(" << cnt++ << ")]]," + << std::endl; + } + if (dynamic_dims) { + os << " constant const int& ndim [[buffer(" << cnt++ << ")]]," + << std::endl; + } + + // The thread index in the whole grid + os << " uint3 pos [[thread_position_in_grid]]," << std::endl + << " uint3 grid [[threads_per_grid]]) {" << std::endl + << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);" + << std::endl; + + // Extract the indices per axis to individual uints if we have arrays that + // are broadcasted or transposed + if (add_indices) { + if (!dynamic_dims) { + if (ndim == 1) { + os << " uint index_0 = pos.x;" << std::endl; + } else if (ndim == 2) { + os << " uint index_0 = pos.y;" << std::endl + << " uint index_1 = pos.x;" << std::endl; + } else if (ndim == 3) { + os << " uint index_0 = pos.z;" << std::endl + << " uint index_1 = pos.y;" << std::endl + << " uint index_2 = pos.x;" << std::endl; + } else { + for (int i = 0; i < ndim - 2; i++) { + os << " uint index_" << i << " = (index / uint(output_strides[" << i + << "])) % output_shape[" << i << "];" << std::endl; + } + os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl + << " uint index_" << ndim - 1 << " = pos.x;" << std::endl; + } + } + } + + // Read the inputs in tmps + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + if (is_constant(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; + print_constant(os, x); + os << ";" << std::endl; + } else if (is_scalar(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[0];" << std::endl; + } else if (contiguous) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[index];" << std::endl; + } else if (!dynamic_dims) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "["; + os << "index_0 * " << xname << "_strides[0]"; + for (int i = 1; i < ndim; i++) { + os << " + index_" << i << " * " << xname << "_strides[" << i << "]"; + } + os << "];" << std::endl; + } else { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[elem_to_loc(index, output_shape, " << xname + << "_strides, ndim)];" << std::endl; + } + } + + // Actually write the computation + for (auto& x : tape) { + os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) + << " = "; + if (is_static_cast(x.primitive())) { + os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" + << namer.get_name(x.inputs()[0]) << ");" << std::endl; + } else { + x.primitive().print(os); + os << "()("; + for (int i = 0; i < x.inputs().size() - 1; i++) { + os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; + } + os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; + } + } + + // Write the outputs from tmps + for (auto& x : outputs) { + os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x) + << ";" << std::endl; + } + + // Finish the kernel + os << "}" << std::endl; + + if (cnt > 31) { + std::ostringstream msg; + msg << "[compile] Too many inputs/outputs fused in the Metal Compile " + << "primitive which exhausted the available argument buffers for " + << "the kernel. Please file an issue with the function that results " + << "in this error. The name of the kernel is '" << kernel_name << "'"; + throw std::runtime_error(msg.str()); + } +} + void Compiled::eval_gpu( const std::vector& inputs, std::vector& outputs) { - // Just a fall-back to the original tape for now - std::unordered_map trace_to_real; - for (int i = 0; i < inputs.size(); ++i) { - trace_to_real.insert({inputs_[i].id(), inputs[i]}); - } - for (int i = 0; i < outputs.size(); ++i) { - trace_to_real.insert({outputs_[i].id(), outputs[i]}); + // Make the name for the kernel library + if (kernel_lib_.empty()) { + kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); } - for (auto& a : tape_) { - std::vector p_inputs; - for (auto& in : a.inputs()) { - p_inputs.push_back(trace_to_real.at(in.id())); - } - // If a is an output get it from the map, otherwise create it - // NB this is safe as long as no multi-output sub primitves are allowed - // in Compiled - std::vector p_outputs; - if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) { - p_outputs.push_back(it->second); - } else { - p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {})); - trace_to_real.insert({a.id(), p_outputs[0]}); - } - a.primitive().eval_gpu(p_inputs, p_outputs); - } + // Get the kernel if someone else built it already auto& s = stream(); auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler( - [trace_to_real](MTL::CommandBuffer*) mutable {}); + auto lib = d.get_library(kernel_lib_); + + // If not we have to build it ourselves + if (lib == nullptr) { + std::ostringstream kernel; + kernel << metal::get_kernel_preamble() << std::endl; + build_kernel( + kernel, + kernel_lib_ + "_contiguous", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ true, + /* ndim = */ 0, + /* dynamic_dims = */ false); + for (int i = 1; i < 8; i++) { + build_kernel( + kernel, + kernel_lib_ + "_strided_" + std::to_string(i), + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ i, + /* dynamic_dims = */ false); + } + build_kernel( + kernel, + kernel_lib_ + "_strided_dynamic", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ 0, + /* dynamic_dims = */ true); + + kernel_source_ = kernel.str(); + lib = d.get_library(kernel_lib_, kernel_source_); + } + + // Allocate space for the outputs + for (auto& out : outputs) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + + // Figure out which kernel we are using + auto& output_shape = outputs[0].shape(); + bool contiguous = true; + for (auto& x : inputs) { + if ((!x.flags().row_contiguous || x.shape() != output_shape) && + x.size() > 1) { + contiguous = false; + break; + } + } + + // Collapse contiguous dims to route to a faster kernel if possible. Also + // handle all broadcasting. + std::vector> initial_strides; + initial_strides.push_back(outputs[0].strides()); + std::vector shape; + std::vector> strides; + if (!contiguous) { + for (int i = 0; i < inputs.size(); i++) { + // Skip constants. + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + + // Skip scalar inputs. + if (x.size() <= 1) { + continue; + } + + // Broadcast the inputs to the output shape. + std::vector xstrides; + int j = 0; + for (; j < output_shape.size() - x.ndim(); j++) { + if (output_shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (int i = 0; i < x.ndim(); i++, j++) { + if (x.shape(i) == 1) { + if (output_shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + initial_strides.push_back(std::move(xstrides)); + } + std::tie(shape, strides) = + collapse_contiguous_dims(output_shape, initial_strides); + } + + // Get the kernel from the lib + int ndim = shape.size(); + bool dynamic = ndim >= 8; + auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + if (!contiguous) { + if (dynamic) { + kernel_name += "dynamic"; + } else { + kernel_name += std::to_string(shape.size()); + } + } + auto kernel = d.get_kernel(kernel_name, lib); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + // Put the inputs in + int cnt = 0; + int stride_idx = 1; // idx 0 is the output strides + for (int i = 0; i < inputs.size(); i++) { + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + set_array_buffer(compute_encoder, x, cnt++); + if (!contiguous && x.size() > 1) { + compute_encoder->setBytes( + strides[stride_idx].data(), + strides[stride_idx].size() * sizeof(size_t), + cnt++); + stride_idx++; + } + } + + // Put the outputs in + for (auto& x : outputs) { + set_array_buffer(compute_encoder, x, cnt++); + } + + // Put the output shape and strides in + if (!contiguous) { + compute_encoder->setBytes( + strides[0].data(), strides[0].size() * sizeof(size_t), cnt++); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++); + } + + // Put the number of dims in if it is dynamic + if (dynamic) { + compute_encoder->setBytes(&ndim, sizeof(int), cnt++); + } + + // Launch the kernel + if (contiguous) { + size_t nthreads = outputs[0].size(); + MTL::Size grid_dims(nthreads, 1, 1); + MTL::Size group_dims( + std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t rest = outputs[0].size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/compiled_preamble.h b/mlx/backend/metal/compiled_preamble.h new file mode 100644 index 000000000..9122d3d54 --- /dev/null +++ b/mlx/backend/metal/compiled_preamble.h @@ -0,0 +1,9 @@ +// Copyright © 2023-24 Apple Inc. + +#pragma once + +namespace mlx::core::metal { + +const char* get_kernel_preamble(); + +} diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7c61e68ae..e50441d48 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -414,6 +414,11 @@ MTL::ComputePipelineState* Device::get_kernel_( return kernel; } +MTL::Library* Device::get_library(const std::string& name) { + auto it = library_map_.find(name); + return (it != library_map_.end()) ? it->second : nullptr; +} + MTL::Library* Device::get_library( const std::string& name, const std::string& source, diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 6acfe9332..8312084ce 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -62,6 +62,8 @@ class Device { const std::function& lib_path_func = get_colocated_mtllib_path); + MTL::Library* get_library(const std::string& name); + MTL::Library* get_library( const std::string& name, const std::string& source_string, diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h new file mode 100644 index 000000000..8adb84c58 --- /dev/null +++ b/mlx/backend/metal/kernels/binary.h @@ -0,0 +1,221 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + T operator()(T x, T y) { + return x % y; + } + template <> + float operator()(float x, float y) { + return fmod(x, y); + } + template <> + half operator()(half x, half y) { + return fmod(x, y); + } + template <> + bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { + return fmod(x, y); + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && + metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + if (metal::isnan(x) || metal::isnan(y)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : (maxval + log1p(metal::exp(minval - maxval))); + }; +}; + +struct Maximum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::max(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x > y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x > y ? x : y; + } +}; + +struct Minimum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::min(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x < y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x < y ? x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + auto x_theta = metal::atan(x.imag / x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 1b84c70a5..4d449ab69 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -1,176 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/bf16.h" - -struct Add { - template T operator()(T x, T y) { return x + y; } -}; - -struct Divide { - template T operator()(T x, T y) { return x / y; } -}; - -struct Remainder { - template T operator()(T x, T y) { return x % y; } - template <> float operator()(float x, float y) { return fmod(x, y); } - template <> half operator()(half x, half y) { return fmod(x, y); } - template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); } -}; - -struct Equal { - template bool operator()(T x, T y) { return x == y; } -}; - -struct NaNEqual { - template bool operator()(T x, T y) { - return x == y || (metal::isnan(x) && metal::isnan(y)); - } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x == y || - (metal::isnan(x.real) && metal::isnan(y.real) - && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); - } -}; - -struct Greater { - template bool operator()(T x, T y) { return x > y; } -}; - -struct GreaterEqual { - template bool operator()(T x, T y) { return x >= y; } -}; - -struct Less { - template bool operator()(T x, T y) { return x < y; } -}; - -struct LessEqual { - template bool operator()(T x, T y) { return x <= y; } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - if (metal::isnan(x) || metal::isnan(y)) { - return metal::numeric_limits::quiet_NaN(); - } - constexpr T inf = metal::numeric_limits::infinity(); - T maxval = metal::max(x, y); - T minval = metal::min(x, y); - return (minval == -inf || maxval == inf) ? maxval : - (maxval + log1p(metal::exp(minval - maxval))); - }; -}; - -struct Maximum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::max(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x > y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x > y ? x : y; - } -}; - -struct Minimum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::min(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x < y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x < y ? x : y; - } -}; - -struct Multiply { - template T operator()(T x, T y) { return x * y; } -}; - -struct NotEqual { - template bool operator()(T x, T y) { return x != y; } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x.real != y.real || x.imag != y.imag; - } -}; - -struct Power { - - template - metal::enable_if_t, T> operator()(T base, T exp) { - return metal::pow(base, exp); - } - - template - metal::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - auto x_theta = metal::atan(x.imag / x.real); - auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); - auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); - auto phase = y.imag * x_ln_r + y.real * x_theta; - return {mag * metal::cos(phase), mag * metal::sin(phase)}; - } -}; - - -struct Subtract { - template T operator()(T x, T y) { return x - y; } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { return x && y; }; -}; - -struct LogicalOr { - template - T operator()(T x, T y) { return x || y; }; -}; +#include "mlx/backend/metal/kernels/binary.h" template [[kernel]] void binary_op_s2s( diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h new file mode 100644 index 000000000..82a9e9c5c --- /dev/null +++ b/mlx/backend/metal/kernels/compiled_preamble.h @@ -0,0 +1,4 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/kernels/binary.h" +#include "mlx/backend/metal/kernels/unary.h" diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h new file mode 100644 index 000000000..6d086b775 --- /dev/null +++ b/mlx/backend/metal/kernels/unary.h @@ -0,0 +1,376 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct Abs { + template + T operator()(T x) { + return metal::abs(x); + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return metal::precise::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return metal::precise::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return metal::precise::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return metal::precise::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return metal::precise::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return metal::precise::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return metal::ceil(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return metal::precise::cos(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Cosh { + template + T operator()(T x) { + return metal::precise::cosh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return metal::precise::exp(x); + }; + template <> + complex64_t operator()(complex64_t x) { + auto m = metal::precise::exp(x.real); + return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + } +}; + +struct Floor { + template + T operator()(T x) { + return metal::floor(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return metal::precise::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return metal::precise::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return metal::precise::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return metal::rint(x); + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::rint(x.real), metal::rint(x.imag)}; + }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + }; + template <> + uint32_t operator()(uint32_t x) { + return x != 0; + }; +}; + +struct Sin { + template + T operator()(T x) { + return metal::precise::sin(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Sinh { + template + T operator()(T x) { + return metal::precise::sinh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return metal::precise::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return metal::precise::rsqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return metal::precise::tan(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + }; +}; + +struct Tanh { + template + T operator()(T x) { + return metal::precise::tanh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + }; +}; diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 681d7707f..154db0520 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -1,223 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/erf.h" -#include "mlx/backend/metal/kernels/bf16.h" - -struct Abs { - template T operator()(T x) { return metal::abs(x); }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; - template <> complex64_t operator()(complex64_t x) { - return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; - }; -}; - -struct ArcCos { - template T operator()(T x) { return metal::precise::acos(x); }; -}; - -struct ArcCosh { - template T operator()(T x) { return metal::precise::acosh(x); }; -}; - -struct ArcSin { - template T operator()(T x) { return metal::precise::asin(x); }; -}; - -struct ArcSinh { - template T operator()(T x) { return metal::precise::asinh(x); }; -}; - -struct ArcTan { - template T operator()(T x) { return metal::precise::atan(x); }; -}; - -struct ArcTanh { - template T operator()(T x) { return metal::precise::atanh(x); }; -}; - -struct Ceil { - template T operator()(T x) { return metal::ceil(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; -}; - -struct Cos { - template T operator()(T x) { return metal::precise::cos(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cos(x.real) * metal::precise::cosh(x.imag), - -metal::precise::sin(x.real) * metal::precise::sinh(x.imag) - }; - }; -}; - -struct Cosh { - template T operator()(T x) { return metal::precise::cosh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cosh(x.real) * metal::precise::cos(x.imag), - metal::precise::sinh(x.real) * metal::precise::sin(x.imag) - }; - }; -}; - -struct Erf { - template T operator()(T x) { return static_cast(erf(static_cast(x))); }; -}; - -struct ErfInv { - template T operator()(T x) { return static_cast(erfinv(static_cast(x))); }; -}; - -struct Exp { - template T operator()(T x) { return metal::precise::exp(x); }; - template <> complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; - } -}; - -struct Floor { - template T operator()(T x) { return metal::floor(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; -}; - -struct Log { - template T operator()(T x) { return metal::precise::log(x); }; -}; - -struct Log2 { - template T operator()(T x) { return metal::precise::log2(x); }; -}; - -struct Log10 { - template T operator()(T x) { return metal::precise::log10(x); }; -}; - -struct Log1p { - template T operator()(T x) { return log1p(x); }; -}; - -struct LogicalNot { - template T operator()(T x) { return !x; }; -}; - -struct Negative { - template T operator()(T x) { return -x; }; -}; - -struct Round { - template T operator()(T x) { return metal::rint(x); }; - template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; }; -}; - -struct Sigmoid { - template - T operator()(T x) { - auto y = 1 / (1 + metal::exp(-metal::abs(x))); - return (x < 0) ? 1 - y : y; - } -}; - -struct Sign { - template T operator()(T x) { return (x > T(0)) - (x < T(0)); }; - template <> uint32_t operator()(uint32_t x) { return x != 0; }; -}; - -struct Sin { - template T operator()(T x) { return metal::precise::sin(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sin(x.real) * metal::precise::cosh(x.imag), - metal::precise::cos(x.real) * metal::precise::sinh(x.imag) - }; - }; -}; - -struct Sinh { - template T operator()(T x) { return metal::precise::sinh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sinh(x.real) * metal::precise::cos(x.imag), - metal::precise::cosh(x.real) * metal::precise::sin(x.imag) - }; - }; -}; - -struct Square { - template T operator()(T x) { return x * x; }; -}; - -struct Sqrt { - template T operator()(T x) { return metal::precise::sqrt(x); }; -}; - -struct Rsqrt { - template T operator()(T x) { return metal::precise::rsqrt(x); }; -}; - -struct Tan { - template T operator()(T x) { return metal::precise::tan(x); }; - - template <> - complex64_t operator()(complex64_t x) { - float tan_a = metal::precise::tan(x.real); - float tanh_b = metal::precise::tanh(x.imag); - float t1 = tan_a * tanh_b; - float denom = 1. + t1 * t1; - return { - (tan_a - tanh_b * t1) / denom, - (tanh_b + tan_a * t1) / denom - }; - }; -}; - -struct Tanh { - template T operator()(T x) { return metal::precise::tanh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - float tanh_a = metal::precise::tanh(x.real); - float tan_b = metal::precise::tan(x.imag); - float t1 = tanh_a * tan_b; - float denom = 1. + t1 * t1; - return { - (tanh_a + tan_b * t1) / denom, - (tan_b - tanh_a * t1) / denom - }; - }; -}; +#include "mlx/backend/metal/kernels/unary.h" template [[kernel]] void unary_op_v( diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 634a9d6df..f9d507cf2 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -12,10 +12,10 @@ template struct Limits { - static const constant U max; - static const constant U min; - static const constant U finite_max; - static const constant U finite_min; + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); }; #define instantiate_default_limit(type) \ @@ -273,4 +273,4 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { inline bool simd_shuffle_down(bool data, uint16_t delta) { return simd_shuffle_down(static_cast(data), delta); -} \ No newline at end of file +} diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh new file mode 100644 index 000000000..1271f567d --- /dev/null +++ b/mlx/backend/metal/make_compiled_preamble.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# +# This script generates a C++ function that provides the Metal unary and binary +# ops at runtime for use with kernel generation. +# +# Copyright © 2023-24 Apple Inc. + + +OUTPUT_FILE=$1 +CC=$2 +SRCDIR=$3 + +CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null) + +cat << EOF > $OUTPUT_FILE +// Copyright © 2023-24 Apple Inc. + +namespace mlx::core::metal { + +const char* get_kernel_preamble() { + return R"preamble( +$CONTENT +)preamble"; + +} + +} // namespace mlx::core::metal +EOF diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 378850802..f7c672c9f 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -117,16 +117,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) { // When multiple arrays are passed they should all have the same shape. The // collapsed axes are also the same so one shape is returned. std::tuple, std::vector>> -collapse_contiguous_dims(const std::vector& xs) { +collapse_contiguous_dims( + const std::vector& shape, + const std::vector> strides) { // Make a vector that has axes separated with -1. Collapse all axes between // -1. std::vector to_collapse; - if (xs[0].ndim() > 0) { + if (shape.size() > 0) { to_collapse.push_back(0); - for (int i = 1; i < xs[0].ndim(); i++) { + for (int i = 1; i < shape.size(); i++) { bool contiguous = true; - for (auto& x : xs) { - if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) { + for (const std::vector& st : strides) { + if (st[i] * shape[i] != st[i - 1]) { contiguous = false; } if (!contiguous) { @@ -142,21 +144,31 @@ collapse_contiguous_dims(const std::vector& xs) { } std::vector out_shape; - std::vector> out_strides(xs.size()); + std::vector> out_strides(strides.size()); for (int i = 0; i < to_collapse.size(); i++) { - int current_shape = xs[0].shape()[to_collapse[i]]; + int current_shape = shape[to_collapse[i]]; while (to_collapse[++i] != -1) { - current_shape *= xs[0].shape()[to_collapse[i]]; + current_shape *= shape[to_collapse[i]]; } out_shape.push_back(current_shape); - for (int j = 0; j < xs.size(); j++) { - out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]); + for (int j = 0; j < strides.size(); j++) { + const std::vector& st = strides[j]; + out_strides[j].push_back(st[to_collapse[i - 1]]); } } return std::make_tuple(out_shape, out_strides); } +std::tuple, std::vector>> +collapse_contiguous_dims(const std::vector& xs) { + std::vector> strides; + for (auto& x : xs) { + strides.emplace_back(x.strides()); + } + return collapse_contiguous_dims(xs[0].shape(), strides); +} + template std::tuple, std::vector>> collapse_contiguous_dims(Arrays... xs) { diff --git a/mlx/compile.cpp b/mlx/compile.cpp index fa9e0a987..c8ee3b0da 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -13,7 +13,7 @@ namespace mlx::core { -constexpr int max_compile_depth = 6; +constexpr int max_compile_depth = 10; bool is_unary(const Primitive& p) { return ( diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 7c1c17740..ba031c441 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -12,27 +12,24 @@ namespace mlx::core { -struct NodeNamer { - std::unordered_map names; - - std::string get_name(const array& x) { - auto it = names.find(x.id()); - if (it == names.end()) { - // Get the next name in the sequence - // [A, B, ..., Z, AA, AB, ...] - std::vector letters; - auto var_num = names.size() + 1; - while (var_num > 0) { - letters.push_back('A' + (var_num - 1) % 26); - var_num = (var_num - 1) / 26; - } - std::string name(letters.rbegin(), letters.rend()); - names.insert({x.id(), name}); - return name; +const std::string& NodeNamer::get_name(const array& x) { + auto it = names.find(x.id()); + if (it == names.end()) { + // Get the next name in the sequence + // [A, B, ..., Z, AA, AB, ...] + std::vector letters; + auto var_num = names.size() + 1; + while (var_num > 0) { + letters.push_back('A' + (var_num - 1) % 26); + var_num = (var_num - 1) / 26; } - return it->second; + std::string name(letters.rbegin(), letters.rend()); + names.insert({x.id(), name}); + + return get_name(x); } -}; + return it->second; +} void depth_first_traversal( std::function callback, diff --git a/mlx/graph_utils.h b/mlx/graph_utils.h index 3bd373bec..5e024704e 100644 --- a/mlx/graph_utils.h +++ b/mlx/graph_utils.h @@ -6,6 +6,12 @@ namespace mlx::core { +struct NodeNamer { + std::unordered_map names; + + const std::string& get_name(const array& x); +}; + void print_graph(std::ostream& os, const std::vector& outputs); template diff --git a/mlx/primitives.h b/mlx/primitives.h index 5bdee12cf..b06a35780 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -473,23 +473,30 @@ class Compiled : public Primitive { void eval_cpu(const std::vector& inputs, std::vector& outputs) override; - void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_VMAP() DEFINE_GRADS() - void print(std::ostream& os) override; - bool is_equivalent(const Primitive& other) const override; + std::string metal_lib_name() const { + return kernel_lib_; + } + std::string metal_lib_source() const { + return kernel_source_; + } + private: const std::vector inputs_; const std::vector outputs_; const std::vector tape_; const std::unordered_set constant_ids_; + std::string kernel_lib_; + std::string kernel_source_; + void eval(const std::vector& inputs, std::vector& out); }; @@ -709,9 +716,16 @@ class Equal : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Equal) DEFINE_DEFAULT_IS_EQUIVALENT() + void print(std::ostream& os) override { + if (equal_nan_) { + os << "NanEqual"; + } else { + os << "Equal"; + } + } + private: void eval(const std::vector& inputs, array& out); bool equal_nan_; @@ -945,9 +959,22 @@ class Log : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Log) DEFINE_DEFAULT_IS_EQUIVALENT() + void print(std::ostream& os) override { + switch (base_) { + case e: + os << "Log"; + break; + case two: + os << "Log2"; + break; + case ten: + os << "Log10"; + break; + } + } + private: Base base_; void eval(const std::vector& inputs, array& out); @@ -1594,9 +1621,16 @@ class Sqrt : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sqrt) bool is_equivalent(const Primitive& other) const override; + void print(std::ostream& os) override { + if (recip_) { + os << "Rsqrt"; + } else { + os << "Sqrt"; + } + } + private: void eval(const std::vector& inputs, array& out); bool recip_; diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 935a881a8..be460e3b6 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -623,3 +623,63 @@ TEST_CASE("test transform compiled function") { CHECK(!outs[0].inputs()[0].has_primitive()); CHECK(!outs[0].inputs()[1].has_primitive()); } + +TEST_CASE("test metal fusion kernel reuse") { + if (default_device() != Device::gpu) { + return; + } + + auto cfun = compile(gelu_1); + auto x = array({2.0f, -2.0f}); + auto y = cfun({x})[0]; + auto p = std::dynamic_pointer_cast(y.primitive_ptr()); + eval(y); + + std::string lib_name = p->metal_lib_name(); + std::string lib_source = p->metal_lib_source(); + CHECK(!lib_name.empty()); + CHECK(!lib_source.empty()); + + x = astype(reshape(arange(10), {2, 5}), float32); + auto z = cfun({x})[0]; + auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); + eval(z); + + std::string lib_name_z = pz->metal_lib_name(); + std::string lib_source_z = pz->metal_lib_source(); + CHECK(!lib_name_z.empty()); + CHECK(lib_source_z.empty()); + + CHECK_EQ(lib_name, lib_name_z); +} + +auto add3(const std::vector& xs) { + return std::vector{xs[0] + xs[0] + xs[0]}; +} + +TEST_CASE("test metal fusion types") { + if (default_device() != Device::gpu) { + return; + } + + auto cfun = compile(add3); + auto x = array({2.0f, -2.0f}); + auto y = cfun({x})[0]; + auto p = std::dynamic_pointer_cast(y.primitive_ptr()); + eval(y); + + std::string lib_name = p->metal_lib_name(); + std::string lib_source = p->metal_lib_source(); + CHECK(!lib_name.empty()); + CHECK(!lib_source.empty()); + + x = array({2, -2}, int32); + auto z = cfun({x})[0]; + auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); + eval(z); + + std::string lib_name_z = pz->metal_lib_name(); + std::string lib_source_z = pz->metal_lib_source(); + CHECK(!lib_name_z.empty()); + CHECK(!lib_source_z.empty()); +}