diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1cd5892db..50d12cdfc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -232,6 +232,14 @@ if(MPI_FOUND)
   endif()
 endif()
 
+message(STATUS "Downloading json")
+FetchContent_Declare(
+  json
+  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
+FetchContent_MakeAvailable(json)
+target_include_directories(
+  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
+
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
 
 target_include_directories(
diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst
index c08614d03..4272371cb 100644
--- a/docs/src/dev/extensions.rst
+++ b/docs/src/dev/extensions.rst
@@ -22,12 +22,12 @@ You can do that in MLX directly:
 This function performs that operation while leaving the implementation and
 function transformations to MLX.
 
-However you may need to customize the underlying implementation, perhaps to
-make it faster or for custom differentiation. In this tutorial we will go
-through adding custom extensions. It will cover:
+However, you may want to customize the underlying implementation, perhaps to
+make it faster. In this tutorial we will go through adding custom extensions.
+It will cover:
 
 * The structure of the MLX library.
-* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
+* Implementing a CPU operation.
 * Implementing a GPU operation using metal.
 * Adding the ``vjp`` and ``jvp`` function transformation.
 * Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
 Operations are the front-end functions that operate on arrays. They are defined
 in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
 
-We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
+We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
 ``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
 C++:
 
@@ -55,7 +55,7 @@ C++:
    * Scale and sum two vectors element-wise
    * z = alpha * x + beta * y
    *
-   * Follow numpy style broadcasting between x and y
+   * Use NumPy-style broadcasting between x and y
    * Inputs are upcasted to floats if needed
    **/
   array axpby(
@@ -66,7 +66,7 @@ C++:
       StreamOrDevice s = {} // Stream on which to schedule the operation
   );
 
-The simplest way to this operation is in terms of existing operations:
+The simplest way to implement this is with existing operations:
 
 .. code-block:: C++
 
@@ -153,9 +153,6 @@ more concrete:
  private:
   float alpha_;
   float beta_;
-
-  /** Fall back implementation for evaluation on CPU */
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
     auto promoted_dtype = promote_types(x.dtype(), y.dtype());
 
     // Upcast to float32 for non-floating point inputs x and y
-    auto out_dtype = is_floating_point(promoted_dtype)
+    auto out_dtype = issubdtype(promoted_dtype, float32)
         ? promoted_dtype
         : promote_types(promoted_dtype, float32);
 
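+For example, under this rule two integer inputs produce a ``float32`` output,
+while floating point inputs keep their promoted type. A quick sketch of the
+behavior (not part of the library sources; assumes ``axpby`` and the MLX
+symbols are in scope as in the snippets above):
+
+.. code-block:: C++
+
+  // int32 promoted with int32 is int32, which fails the float check,
+  // so the output is upcast to float32
+  auto zi = axpby(ones({2}, int32), ones({2}, int32), 1.0f, 1.0f);
+  assert(zi.dtype() == float32);
+
+  // float32 inputs keep their promoted floating point type
+  auto zf = axpby(ones({2}, float32), ones({2}, int32), 1.0f, 1.0f);
+  assert(zf.dtype() == float32);
+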
@@ -234,49 +231,59 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
 
 Implementing the CPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Let's start by implementing a naive and generic version of
-:meth:`Axpby::eval_cpu`. We declared this as a private member function of
-:class:`Axpby` earlier called :meth:`Axpby::eval`.
+Let's start by implementing :meth:`Axpby::eval_cpu`.
 
-Our naive method will go over each element of the output array, find the
+The method will go over each element of the output array, find the
 corresponding input elements of ``x`` and ``y`` and perform the operation
 point-wise. This is captured in the templated function :meth:`axpby_impl`.
 
 .. code-block:: C++
 
-  template <typename T>
-  void axpby_impl(
-      const array& x,
-      const array& y,
-      array& out,
-      float alpha_,
-      float beta_) {
-    // We only allocate memory when we are ready to fill the output
-    // malloc_or_wait synchronously allocates available memory
-    // There may be a wait executed here if the allocation is requested
-    // under memory-pressured conditions
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  template <typename T>
+  void axpby_impl(
+      const mx::array& x,
+      const mx::array& y,
+      mx::array& out,
+      float alpha_,
+      float beta_,
+      mx::Stream stream) {
+    // Allocate the output with `malloc_or_wait` which synchronously allocates
+    // memory, potentially waiting if the system is under memory pressure
+    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
 
-    // Collect input and output data pointers
-    const T* x_ptr = x.data<T>();
-    const T* y_ptr = y.data<T>();
-    T* out_ptr = out.data<T>();
+    // Get the CPU command encoder and register input and output arrays
+    auto& encoder = mx::cpu::get_command_encoder(stream);
+    encoder.set_input_array(x);
+    encoder.set_input_array(y);
+    encoder.set_output_array(out);
 
-    // Cast alpha and beta to the relevant types
-    T alpha = static_cast<T>(alpha_);
-    T beta = static_cast<T>(beta_);
+    // Launch the CPU kernel
+    encoder.dispatch([x_ptr = x.data<T>(),
+                      y_ptr = y.data<T>(),
+                      out_ptr = out.data<T>(),
+                      size = out.size(),
+                      shape = out.shape(),
+                      x_strides = x.strides(),
+                      y_strides = y.strides(),
+                      alpha_,
+                      beta_]() {
 
-    // Do the element-wise operation for each output
-    for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
-      // Map linear indices to offsets in x and y
-      auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
-      auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
+      // Cast alpha and beta to the relevant types
+      T alpha = static_cast<T>(alpha_);
+      T beta = static_cast<T>(beta_);
 
-      // We allocate the output to be contiguous and regularly strided
-      // (defaults to row major) and hence it doesn't need additional mapping
-      out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
-    }
-  }
+      // Do the element-wise operation for each output
+      for (size_t out_idx = 0; out_idx < size; out_idx++) {
+        // Map linear indices to offsets in x and y
+        auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
+        auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
+
+        // We allocate the output to be contiguous and regularly strided
+        // (defaults to row major) and hence it doesn't need additional mapping
+        out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
+      }
+    });
+  }
 
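+The only non-obvious helper above is ``elem_to_loc``. It maps a linear index
+into the row-contiguous output to an offset in a possibly strided (or
+broadcast) input. For reference, it behaves roughly like the following
+sketch; the real helper lives in ``mlx/backend/common/utils.h``:
+
+.. code-block:: C++
+
+  int64_t elem_to_loc_sketch(
+      int elem,
+      const mx::Shape& shape,
+      const mx::Strides& strides) {
+    int64_t loc = 0;
+    for (int i = shape.size() - 1; i >= 0; --i) {
+      // Broadcast dimensions have stride 0 and contribute no offset
+      loc += (elem % shape[i]) * strides[i];
+      elem /= shape[i];
+    }
+    return loc;
+  }
+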
 Our implementation should work for all incoming floating point arrays.
 Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
 ``complex64``. We throw an error if we encounter any other type.
 
 .. code-block:: C++
 
-  /** Fall back implementation for evaluation on CPU */
-  void Axpby::eval(
-      const std::vector<array>& inputs,
-      const std::vector<array>& outputs) {
-    auto& x = inputs[0];
-    auto& y = inputs[1];
-    auto& out = outputs[0];
-
-    // Dispatch to the correct dtype
-    if (out.dtype() == float32) {
-      return axpby_impl<float>(x, y, out, alpha_, beta_);
-    } else if (out.dtype() == float16) {
-      return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
-    } else if (out.dtype() == bfloat16) {
-      return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
-    } else if (out.dtype() == complex64) {
-      return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
-    } else {
-      throw std::runtime_error(
-          "[Axpby] Only supports floating point types.");
-    }
-  }
-
-This is good as a fallback implementation. We can use the ``axpby`` routine
-provided by the Accelerate_ framework for a faster implementation in certain
-cases:
-
-#. Accelerate does not provide implementations of ``axpby`` for half precision
-   floats. We can only use it for ``float32`` types.
-#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
-   elements have fixed strides between them. We only direct to Accelerate
-   if both ``x`` and ``y`` are row contiguous or column contiguous.
-#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
-   MLX expects to write the output to a new array. We must copy the elements
-   of ``y`` into the output and use that as an input to ``axpby``.
-
-Let's write an implementation that uses Accelerate in the right conditions.
-It allocates data for the output, copies ``y`` into it, and then calls the
-:func:`catlas_saxpby` from accelerate.
-
-.. code-block:: C++
-
-  template <typename T>
-  void axpby_impl_accelerate(
-      const array& x,
-      const array& y,
-      array& out,
-      float alpha_,
-      float beta_) {
-    // Accelerate library provides catlas_saxpby which does
-    // Y = (alpha * X) + (beta * Y) in place
-    // To use it, we first copy the data in y over to the output array
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-    // We then copy over the elements using the contiguous vector specialization
-    copy_inplace(y, out, CopyType::Vector);
-
-    // Get x and y pointers for catlas_saxpby
-    const T* x_ptr = x.data<T>();
-    T* y_ptr = out.data<T>();
-
-    T alpha = static_cast<T>(alpha_);
-    T beta = static_cast<T>(beta_);
-
-    // Call the inplace accelerate operator
-    catlas_saxpby(
-        /* N = */ out.size(),
-        /* ALPHA = */ alpha,
-        /* X = */ x_ptr,
-        /* INCX = */ 1,
-        /* BETA = */ beta,
-        /* Y = */ y_ptr,
-        /* INCY = */ 1);
-  }
-
-For inputs that do not fit the criteria for accelerate, we fall back to
-:meth:`Axpby::eval`. With this in mind, let's finish our
-:meth:`Axpby::eval_cpu`.
-
-.. code-block:: C++
-
-  /** Evaluate primitive on CPU using accelerate specializations */
   void Axpby::eval_cpu(
-      const std::vector<array>& inputs,
-      const std::vector<array>& outputs) {
-    assert(inputs.size() == 2);
-    auto& x = inputs[0];
-    auto& y = inputs[1];
-    auto& out = outputs[0];
+      const std::vector<mx::array>& inputs,
+      std::vector<mx::array>& outputs) {
+    auto& x = inputs[0];
+    auto& y = inputs[1];
+    auto& out = outputs[0];
 
-    // Accelerate specialization for contiguous single precision float arrays
-    if (out.dtype() == float32 &&
-        ((x.flags().row_contiguous && y.flags().row_contiguous) ||
-         (x.flags().col_contiguous && y.flags().col_contiguous))) {
-      axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
-      return;
-    }
-
-    // Fall back to common back-end if specializations are not available
-    eval(inputs, outputs);
+    // Dispatch to the correct dtype
+    if (out.dtype() == mx::float32) {
+      return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
+    } else if (out.dtype() == mx::float16) {
+      return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
+    } else if (out.dtype() == mx::bfloat16) {
+      return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
+    } else if (out.dtype() == mx::complex64) {
+      return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
+    } else {
+      throw std::runtime_error(
+          "Axpby is only supported for floating point types.");
+    }
   }
 
 Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
 you do not plan on running the operation on the GPU or using transforms on
 computation graphs that contain :class:`Axpby`, you can stop implementing the
-primitive here and enjoy the speed-ups you get from the Accelerate library.
+primitive here.
 
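+As a quick sanity check, the op can now be called and evaluated on a CPU
+stream. A minimal sketch (assuming ``my_ext::axpby`` and the ``mx`` namespace
+alias from the sources above):
+
+.. code-block:: C++
+
+  auto x = mx::ones({3, 4});
+  auto y = mx::ones({3, 4});
+  // Schedule explicitly on the CPU; the command encoder runs the kernel
+  // asynchronously and eval() waits for the result
+  auto z = my_ext::axpby(x, y, 4.0f, 2.0f, mx::Device::cpu);
+  z.eval();  // every element is 4 * 1 + 2 * 1 = 6
+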
 Implementing the GPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -824,7 +751,7 @@ Results
 ^^^^^^^
 
 Let's run a quick benchmark and see how our new ``axpby`` operation compares
-with the naive :meth:`simple_axpby` we first defined on the CPU.
+with the naive :meth:`simple_axpby` we first defined.
 
 .. code-block:: python
 
@@ -832,13 +759,11 @@ with the naive :meth:`simple_axpby` we first defined.
    import mlx.core as mx
    from mlx_sample_extensions import axpby
    import time
 
-   mx.set_default_device(mx.cpu)
-
   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y
 
-   M = 256
-   N = 512
+   M = 4096
+   N = 4096
 
    x = mx.random.normal((M, N))
    y = mx.random.normal((M, N))
@@ -849,24 +774,24 @@ with the naive :meth:`simple_axpby` we first defined.
 
    def bench(f):
        # Warm up
-       for i in range(100):
+       for i in range(5):
            z = f(x, y, alpha, beta)
            mx.eval(z)
 
        # Timed run
        s = time.time()
-       for i in range(5000):
+       for i in range(100):
            z = f(x, y, alpha, beta)
            mx.eval(z)
        e = time.time()
-       return e - s
+       return 1000 * (e - s) / 100
 
    simple_time = bench(simple_axpby)
    custom_time = bench(axpby)
 
-   print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
+   print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
 
-The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
+The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
 modest improvements right away!
 
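+Because :class:`Axpby` also implements :meth:`vjp`, gradients flow through
+the custom primitive exactly like through built-in ops. A small C++ sketch
+(assuming the same ``my_ext::axpby`` and ``mx`` alias as above):
+
+.. code-block:: C++
+
+  auto fn = [](mx::array x) {
+    auto y = mx::ones({4});
+    // sum(alpha * x + beta * y), so d/dx is alpha everywhere
+    return mx::sum(my_ext::axpby(x, y, 4.0f, 2.0f));
+  };
+  auto dfdx = mx::grad(fn)(mx::ones({4}));  // all entries equal alpha = 4
+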
This operation is now good to be used to build other operations, in diff --git a/examples/extensions/CMakeLists.txt b/examples/extensions/CMakeLists.txt index db2ba9b59..0f7018720 100644 --- a/examples/extensions/CMakeLists.txt +++ b/examples/extensions/CMakeLists.txt @@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON) # ----------------------------- Dependencies ----------------------------- -find_package(MLX CONFIG REQUIRED) find_package( Python 3.8 COMPONENTS Interpreter Development.Module @@ -21,6 +20,12 @@ execute_process( OUTPUT_VARIABLE nanobind_ROOT) find_package(nanobind CONFIG REQUIRED) +execute_process( + COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE MLX_ROOT) +find_package(MLX CONFIG REQUIRED) + # ----------------------------- Extensions ----------------------------- # Add library diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index aa312bb3a..d3663e71e 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -1,20 +1,14 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2023-2025 Apple Inc. -#include #include #include -#include "mlx/backend/common/copy.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/utils.h" #include "axpby/axpby.h" -#ifdef ACCELERATE_NEW_LAPACK -#include -#endif - #ifdef _METAL_ #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" @@ -76,136 +70,67 @@ void axpby_impl( const mx::array& y, mx::array& out, float alpha_, - float beta_) { - // We only allocate memory when we are ready to fill the output - // malloc_or_wait synchronously allocates available memory - // There may be a wait executed here if the allocation is requested - // under memory-pressured conditions + float beta_, + mx::Stream stream) { + // Allocate the output with `malloc_or_wait` which synchronously allocates + // memory, potentially waiting if the system is under memory pressure out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); - // Collect input and output data pointers - const T* x_ptr = x.data(); - const T* y_ptr = y.data(); - T* out_ptr = out.data(); + // Get the CPU command encoder and register input and output arrays + auto& encoder = mx::cpu::get_command_encoder(stream); + encoder.set_input_array(x); + encoder.set_input_array(y); + encoder.set_output_array(out); - // Cast alpha and beta to the relevant types - T alpha = static_cast(alpha_); - T beta = static_cast(beta_); + // Launch the CPU kernel + encoder.dispatch([x_ptr = x.data(), + y_ptr = y.data(), + out_ptr = out.data(), + size = out.size(), + shape = out.shape(), + x_strides = x.strides(), + y_strides = y.strides(), + alpha_, + beta_]() { + // Cast alpha and beta to the relevant types + T alpha = static_cast(alpha_); + T beta = static_cast(beta_); - // Do the element-wise operation for each output - for (size_t out_idx = 0; out_idx < out.size(); out_idx++) { - // Map linear indices to offsets in x and y - auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides()); - auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides()); + // Do the element-wise operation for each output + for (size_t out_idx = 0; out_idx < size; out_idx++) { + // Map linear indices to offsets in x and y + auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides); + auto y_offset = mx::elem_to_loc(out_idx, shape, 
y_strides); - // We allocate the output to be contiguous and regularly strided - // (defaults to row major) and hence it doesn't need additional mapping - out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; - } + // We allocate the output to be contiguous and regularly strided + // (defaults to row major) and hence it doesn't need additional mapping + out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; + } + }); } -/** Fall back implementation for evaluation on CPU */ -void Axpby::eval( +void Axpby::eval_cpu( const std::vector& inputs, std::vector& outputs) { - // Check the inputs (registered in the op while constructing the out array) - assert(inputs.size() == 2); auto& x = inputs[0]; auto& y = inputs[1]; auto& out = outputs[0]; // Dispatch to the correct dtype if (out.dtype() == mx::float32) { - return axpby_impl(x, y, out, alpha_, beta_); + return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::float16) { - return axpby_impl(x, y, out, alpha_, beta_); + return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::bfloat16) { - return axpby_impl(x, y, out, alpha_, beta_); + return axpby_impl(x, y, out, alpha_, beta_, stream()); } else if (out.dtype() == mx::complex64) { - return axpby_impl(x, y, out, alpha_, beta_); + return axpby_impl(x, y, out, alpha_, beta_, stream()); } else { throw std::runtime_error( "Axpby is only supported for floating point types."); } } -/////////////////////////////////////////////////////////////////////////////// -// Primitive Accelerate Backend Implementation -/////////////////////////////////////////////////////////////////////////////// - -#ifdef ACCELERATE_NEW_LAPACK - -template -void axpby_impl_accelerate( - const mx::array& x, - const mx::array& y, - mx::array& out, - float alpha_, - float beta_) { - // Accelerate library provides catlas_saxpby which does - // Y = (alpha * X) + (beta * Y) in place - // To use it, we first copy the data in y over to the output array - - // This specialization requires both x and y be contiguous in the same mode - // i.e: corresponding linear indices in both point to corresponding elements - // The data in the output array is allocated to match the strides in y - // such that x, y, and out are contiguous in the same mode and - // no transposition is needed - out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); - - // We then copy over the elements using the contiguous vector specialization - copy_inplace(y, out, mx::CopyType::Vector); - - // Get x and y pointers for catlas_saxpby - const T* x_ptr = x.data(); - T* y_ptr = out.data(); - - T alpha = static_cast(alpha_); - T beta = static_cast(beta_); - - // Call the inplace accelerate operator - catlas_saxpby( - /* N = */ out.size(), - /* ALPHA = */ alpha, - /* X = */ x_ptr, - /* INCX = */ 1, - /* BETA = */ beta, - /* Y = */ y_ptr, - /* INCY = */ 1); -} - -/** Evaluate primitive on CPU using accelerate specializations */ -void Axpby::eval_cpu( - const std::vector& inputs, - std::vector& outputs) { - assert(inputs.size() == 2); - auto& x = inputs[0]; - auto& y = inputs[1]; - auto& out = outputs[0]; - - // Accelerate specialization for contiguous single precision float arrays - if (out.dtype() == mx::float32 && - ((x.flags().row_contiguous && y.flags().row_contiguous) || - (x.flags().col_contiguous && y.flags().col_contiguous))) { - axpby_impl_accelerate(x, y, out, alpha_, beta_); - return; - } - - // Fall back to common backend if specializations are not available - eval(inputs, 
outputs); -} - -#else // Accelerate not available - -/** Evaluate primitive on CPU falling back to common backend */ -void Axpby::eval_cpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -#endif - /////////////////////////////////////////////////////////////////////////////// // Primitive Metal Backend Implementation /////////////////////////////////////////////////////////////////////////////// @@ -217,7 +142,6 @@ void Axpby::eval_gpu( const std::vector& inputs, std::vector& outputs) { // Prepare inputs - assert(inputs.size() == 2); auto& x = inputs[0]; auto& y = inputs[1]; auto& out = outputs[0]; diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h index 76421b493..26f80961c 100644 --- a/examples/extensions/axpby/axpby.h +++ b/examples/extensions/axpby/axpby.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2025 Apple Inc. #pragma once @@ -85,11 +85,6 @@ class Axpby : public mx::Primitive { private: float alpha_; float beta_; - - /** Fall back implementation for evaluation on CPU */ - void eval( - const std::vector& inputs, - std::vector& outputs); }; } // namespace my_ext diff --git a/examples/extensions/axpby/axpby.metal b/examples/extensions/axpby/axpby.metal index bec5a5d53..ce8c54520 100644 --- a/examples/extensions/axpby/axpby.metal +++ b/examples/extensions/axpby/axpby.metal @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2025 Apple Inc. #include diff --git a/mlx/array.cpp b/mlx/array.cpp index 9db264317..406a491e6 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -76,35 +76,27 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter) set_data(data, deleter); } -array::array( - allocator::Buffer data, - Shape shape, - Dtype dtype, - Strides strides, - size_t data_size, - Flags flags, - Deleter deleter) - : array_desc_(std::make_shared(std::move(shape), dtype)) { - set_data(data, data_size, std::move(strides), flags, deleter); -} - void array::detach() { + array_desc_->primitive = nullptr; + for (auto& s : array_desc_->siblings) { + s.array_desc_->primitive = nullptr; + } for (auto& s : array_desc_->siblings) { s.array_desc_->inputs.clear(); s.array_desc_->siblings.clear(); s.array_desc_->position = 0; - s.array_desc_->primitive = nullptr; } array_desc_->inputs.clear(); array_desc_->siblings.clear(); array_desc_->position = 0; - array_desc_->primitive = nullptr; } bool array::is_available() const { if (status() == Status::available) { return true; - } else if (status() == Status::evaluated && event().is_signaled()) { + } else if ( + status() == Status::evaluated && + (!event().valid() || event().is_signaled())) { set_status(Status::available); return true; } @@ -113,7 +105,10 @@ bool array::is_available() const { void array::wait() { if (!is_available()) { - event().wait(); + if (event().valid()) { + event().wait(); + detach_event(); + } set_status(Status::available); } } @@ -174,34 +169,13 @@ void array::copy_shared_buffer(const array& other) { copy_shared_buffer(other, other.strides(), other.flags(), other.data_size()); } -void array::move_shared_buffer( - array other, - const Strides& strides, - Flags flags, - size_t data_size, - size_t offset /* = 0 */) { - array_desc_->data = std::move(other.array_desc_->data); - array_desc_->strides = strides; - array_desc_->flags = flags; - array_desc_->data_size = data_size; - auto char_offset = sizeof(char) * itemsize() * offset; - auto data_ptr = other.array_desc_->data_ptr; - other.array_desc_->data_ptr = 
nullptr; - array_desc_->data_ptr = - static_cast(static_cast(data_ptr) + char_offset); -} - -void array::move_shared_buffer(array other) { - move_shared_buffer(other, other.strides(), other.flags(), other.data_size()); -} - array::~array() { if (array_desc_ == nullptr) { return; } - // Ignore arrays that might be detached during eval - if (status() == array::Status::scheduled) { + // Detached/detaching + if (array_desc_->primitive == nullptr) { return; } diff --git a/mlx/array.h b/mlx/array.h index e5446088b..5e980d532 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -243,18 +243,6 @@ class array { bool col_contiguous : 1; }; - /** Build an array from all the info held by the array description. Including - * the buffer, strides, flags. - */ - explicit array( - allocator::Buffer data, - Shape shape, - Dtype dtype, - Strides strides, - size_t data_size, - Flags flags, - Deleter deleter = allocator::free); - /** The array's primitive. */ Primitive& primitive() const { return *(array_desc_->primitive); @@ -365,11 +353,6 @@ class array { // For example, the status of `x` in `auto x = a + b`. unscheduled, - // The ouptut of a computation which has been scheduled but `eval_*` has - // not yet been called on the array's primitive. A possible - // status of `x` in `auto x = a + b; eval(x);` - scheduled, - // The array's `eval_*` function has been run, but the computation is not // necessarily complete. The array will have memory allocated and if it is // not a tracer then it will be detached from the graph. @@ -406,6 +389,10 @@ class array { array_desc_->event = std::move(e); } + void detach_event() const { + array_desc_->event = Event{}; + } + // Mark the array as a tracer array (true) or not. void set_tracer(bool is_tracer) { array_desc_->is_tracer = is_tracer; @@ -431,15 +418,6 @@ class array { void copy_shared_buffer(const array& other); - void move_shared_buffer( - array other, - const Strides& strides, - Flags flags, - size_t data_size, - size_t offset = 0); - - void move_shared_buffer(array other); - void overwrite_descriptor(const array& other) { array_desc_ = other.array_desc_; } diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index 9a8c10951..0c56b71a0 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -38,8 +38,7 @@ inline void set_binary_op_output_data( const array& a, const array& b, array& out, - BinaryOpType bopt, - bool donate_with_move = false) { + BinaryOpType bopt) { bool b_donatable = is_donatable(b, out); bool a_donatable = is_donatable(a, out); switch (bopt) { @@ -49,11 +48,7 @@ inline void set_binary_op_output_data( break; case BinaryOpType::ScalarVector: if (b_donatable) { - if (donate_with_move) { - out.move_shared_buffer(b); - } else { - out.copy_shared_buffer(b); - } + out.copy_shared_buffer(b); } else { out.set_data( allocator::malloc_or_wait(b.data_size() * out.itemsize()), @@ -64,11 +59,7 @@ inline void set_binary_op_output_data( break; case BinaryOpType::VectorScalar: if (a_donatable) { - if (donate_with_move) { - out.move_shared_buffer(a); - } else { - out.copy_shared_buffer(a); - } + out.copy_shared_buffer(a); } else { out.set_data( allocator::malloc_or_wait(a.data_size() * out.itemsize()), @@ -79,17 +70,9 @@ inline void set_binary_op_output_data( break; case BinaryOpType::VectorVector: if (a_donatable) { - if (donate_with_move) { - out.move_shared_buffer(a); - } else { - out.copy_shared_buffer(a); - } + out.copy_shared_buffer(a); } else if (b_donatable) { - if (donate_with_move) { - out.move_shared_buffer(b); - } else 
{ - out.copy_shared_buffer(b); - } + out.copy_shared_buffer(b); } else { out.set_data( allocator::malloc_or_wait(a.data_size() * out.itemsize()), @@ -100,18 +83,10 @@ inline void set_binary_op_output_data( break; case BinaryOpType::General: if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) { - if (donate_with_move) { - out.move_shared_buffer(a); - } else { - out.copy_shared_buffer(a); - } + out.copy_shared_buffer(a); } else if ( b_donatable && b.flags().row_contiguous && b.size() == out.size()) { - if (donate_with_move) { - out.move_shared_buffer(b); - } else { - out.copy_shared_buffer(b); - } + out.copy_shared_buffer(b); } else { out.set_data(allocator::malloc_or_wait(out.nbytes())); } diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index a63de9aa1..3eb98d09a 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector& inputs, array& out) { // rely on data_size anyway. size_t data_size = out.size(); - return move_or_copy(in, out, strides_, flags, data_size, offset_); + return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); } void broadcast(const array& in, array& out) { @@ -56,7 +56,7 @@ void broadcast(const array& in, array& out) { if (out.size() > in.size()) { flags.row_contiguous = flags.col_contiguous = false; } - move_or_copy(in, out, strides, flags, in.data_size()); + out.copy_shared_buffer(in, strides, flags, in.data_size()); } void Broadcast::eval(const std::vector& inputs, array& out) { @@ -69,7 +69,7 @@ void BroadcastAxes::eval(const std::vector& inputs, array& out) { void Copy::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - move_or_copy(inputs[0], out); + out.copy_shared_buffer(inputs[0]); } void CustomTransforms::eval( @@ -78,7 +78,7 @@ void CustomTransforms::eval( assert(inputs.size() > outputs.size()); for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); i++, j++) { - move_or_copy(inputs[j], outputs[i]); + outputs[i].copy_shared_buffer(inputs[j]); } } @@ -87,7 +87,7 @@ void Depends::eval( std::vector& outputs) { assert(inputs.size() > outputs.size()); for (int i = 0; i < outputs.size(); i++) { - move_or_copy(inputs[i], outputs[i]); + outputs[i].copy_shared_buffer(inputs[i]); } } @@ -98,7 +98,7 @@ void ExpandDims::eval(const std::vector& inputs, array& out) { for (auto ax : axes_) { strides.insert(strides.begin() + ax, 1); } - move_or_copy(in, out, strides, in.flags(), in.data_size()); + out.copy_shared_buffer(in, strides, in.flags(), in.data_size()); } void NumberOfElements::eval(const std::vector& inputs, array& out) { @@ -210,7 +210,7 @@ void shared_buffer_reshape( auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; } - move_or_copy(in, out, out_strides, flags, in.data_size()); + out.copy_shared_buffer(in, out_strides, flags, in.data_size()); } void Split::eval( @@ -276,12 +276,12 @@ void Squeeze::eval(const std::vector& inputs, array& out) { strides.push_back(in.strides(i)); } } - move_or_copy(in, out, strides, in.flags(), in.data_size()); + out.copy_shared_buffer(in, strides, in.flags(), in.data_size()); } void StopGradient::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - move_or_copy(inputs[0], out); + out.copy_shared_buffer(inputs[0]); } void Transpose::eval(const std::vector& inputs, array& out) { @@ -315,7 +315,7 @@ void Transpose::eval(const std::vector& inputs, 
array& out) { b_stride *= out.shape(ri); } } - move_or_copy(in, out, out_strides, flags, in.data_size()); + out.copy_shared_buffer(in, out_strides, flags, in.data_size()); } } // namespace mlx::core diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 75d679d86..dfa9e700a 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -161,8 +161,7 @@ void compiled_allocate_outputs( std::vector& outputs, const std::vector& inputs_, const std::unordered_set& constant_ids_, - bool contiguous, - bool move_buffers /* = false */) { + bool contiguous) { if (contiguous) { int o = 0; Strides strides; @@ -178,11 +177,7 @@ void compiled_allocate_outputs( if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && in.is_donatable() && constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { - if (move_buffers) { - outputs[o++].move_shared_buffer(in); - } else { - outputs[o++].copy_shared_buffer(in); - } + outputs[o++].copy_shared_buffer(in); } // Get representative input flags to properly set non-donated outputs if (strides.empty() && in.size() == outputs[0].size()) { @@ -210,13 +205,8 @@ void compiled_allocate_outputs( if (in.flags().row_contiguous && in.size() == outputs[o].size() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() && constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { - if (move_buffers) { - outputs[o].move_shared_buffer( - in, outputs[o].strides(), in.flags(), in.data_size()); - } else { - outputs[o].copy_shared_buffer( - in, outputs[o].strides(), in.flags(), in.data_size()); - } + outputs[o].copy_shared_buffer( + in, outputs[o].strides(), in.flags(), in.data_size()); o++; } } diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index c79d852a0..f4d28d6ab 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -62,7 +62,6 @@ void compiled_allocate_outputs( std::vector& outputs, const std::vector& inputs_, const std::unordered_set& constant_ids_, - bool contiguous, - bool move_buffers = false); + bool contiguous); } // namespace mlx::core diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index b05967638..5d8f7a58e 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -22,4 +22,25 @@ enum class CopyType { GeneralGeneral }; +inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { + if (ctype == CopyType::Vector) { + // If the input is donateable, we are doing a vector copy and the types + // have the same size, then the input buffer can hold the output. 
+ if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.copy_shared_buffer(in); + return true; + } else { + out.set_data( + allocator::malloc_or_wait(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + return false; + } + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + return false; + } +} + } // namespace mlx::core diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index 1fc6c9b31..3f194f1e2 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -3,7 +3,8 @@ #include #include -#include "mlx/backend/common/load.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" namespace { @@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) { namespace mlx::core { -void load( - array& out, - size_t offset, - const std::shared_ptr& reader, - bool swap_endianness_) { - reader->read(out.data(), out.nbytes(), offset); - - if (swap_endianness_) { - switch (out.itemsize()) { - case 2: - swap_endianness<2>(out.data(), out.data_size()); - break; - case 4: - swap_endianness<4>(out.data(), out.data_size()); - break; - case 8: - swap_endianness<8>(out.data(), out.data_size()); - break; +void Load::eval_cpu(const std::vector& inputs, array& out) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto read_task = [out_ptr = out.data(), + size = out.size(), + itemsize = out.itemsize(), + offset = offset_, + reader = reader_, + swap_endianness_ = swap_endianness_]() mutable { + reader->read(out_ptr, size * itemsize, offset); + if (swap_endianness_) { + switch (itemsize) { + case 2: + swap_endianness<2>(reinterpret_cast(out_ptr), size); + break; + case 4: + swap_endianness<4>(reinterpret_cast(out_ptr), size); + break; + case 8: + swap_endianness<8>(reinterpret_cast(out_ptr), size); + break; + } } - } + }; + auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); + scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); }); } } // namespace mlx::core diff --git a/mlx/backend/common/load.h b/mlx/backend/common/load.h deleted file mode 100644 index 5806cc693..000000000 --- a/mlx/backend/common/load.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright © 2024 Apple Inc. 
- -#include "mlx/array.h" -#include "mlx/io/load.h" - -namespace mlx::core { - -void load( - array& out, - size_t offset, - const std::shared_ptr& reader, - bool swap_endianess); - -} // namespace mlx::core diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index 93b9d480e..6f5736d63 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -36,7 +36,7 @@ void shared_buffer_slice( flags.col_contiguous = is_col_contiguous; flags.contiguous = (no_bsx_size == data_size); - move_or_copy(in, out, out_strides, flags, data_size, data_offset); + out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); } void slice( diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index fad7fa95d..ad6df2fc4 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -36,15 +36,10 @@ inline void set_ternary_op_output_data( const array& b, const array& c, array& out, - TernaryOpType topt, - bool donate_with_move = false) { - auto maybe_donate = [&out, donate_with_move](const array& x) { + TernaryOpType topt) { + auto maybe_donate = [&out](const array& x) { if (is_donatable(x, out)) { - if (donate_with_move) { - out.move_shared_buffer(x); - } else { - out.copy_shared_buffer(x); - } + out.copy_shared_buffer(x); return true; } return false; diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 1179ee9be..35bba9c63 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -4,28 +4,6 @@ namespace mlx::core { -void move_or_copy(const array& in, array& out) { - if (in.is_donatable()) { - out.move_shared_buffer(in); - } else { - out.copy_shared_buffer(in); - } -} - -void move_or_copy( - const array& in, - array& out, - const Strides& strides, - array::Flags flags, - size_t data_size, - size_t offset /* = 0 */) { - if (in.is_donatable()) { - out.move_shared_buffer(in, strides, flags, data_size, offset); - } else { - out.copy_shared_buffer(in, strides, flags, data_size, offset); - } -} - std::tuple> collapse_contiguous_dims( const Shape& shape, const std::vector& strides, diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 5f3376fdf..20a65d7b1 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -159,15 +159,6 @@ inline bool is_donatable(const array& in, const array& out) { in.buffer_size() <= out.nbytes() + donation_extra; } -void move_or_copy(const array& in, array& out); -void move_or_copy( - const array& in, - array& out, - const Strides& strides, - array::Flags flags, - size_t data_size, - size_t offset = 0); - std::pair prepare_reshape(const array& in, const array& out); void shared_buffer_reshape( diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 17a6e053e..e36e0567a 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -44,7 +44,9 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp @@ -65,6 +67,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) 
if(MLX_BUILD_ACCELERATE) diff --git a/mlx/backend/cpu/arange.h b/mlx/backend/cpu/arange.h index fadeafde3..9e9b03bd7 100644 --- a/mlx/backend/cpu/arange.h +++ b/mlx/backend/cpu/arange.h @@ -2,76 +2,27 @@ #pragma once -#include "mlx/allocator.h" #include "mlx/array.h" +#include "mlx/backend/cpu/encoder.h" namespace mlx::core { namespace { template -void arange(T start, T next, array& out, size_t size) { +void arange(T start, T next, array& out, size_t size, Stream stream) { auto ptr = out.data(); auto step_size = next - start; - for (int i = 0; i < size; ++i) { - ptr[i] = start; - start += step_size; - } + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(out); + encoder.dispatch([ptr, start, step_size, size]() mutable { + for (int i = 0; i < size; ++i) { + ptr[i] = start; + start += step_size; + } + }); } } // namespace -void arange( - const std::vector& inputs, - array& out, - double start, - double step) { - assert(inputs.size() == 0); - out.set_data(allocator::malloc_or_wait(out.nbytes())); - switch (out.dtype()) { - case bool_: - throw std::runtime_error("Bool type unsupported for arange."); - break; - case uint8: - arange(start, start + step, out, out.size()); - break; - case uint16: - arange(start, start + step, out, out.size()); - break; - case uint32: - arange(start, start + step, out, out.size()); - break; - case uint64: - arange(start, start + step, out, out.size()); - break; - case int8: - arange(start, start + step, out, out.size()); - break; - case int16: - arange(start, start + step, out, out.size()); - break; - case int32: - arange(start, start + step, out, out.size()); - break; - case int64: - arange(start, start + step, out, out.size()); - break; - case float16: - arange(start, start + step, out, out.size()); - break; - case float32: - arange(start, start + step, out, out.size()); - break; - case float64: - arange(start, start + step, out, out.size()); - break; - case bfloat16: - arange(start, start + step, out, out.size()); - break; - case complex64: - arange(start, start + step, out, out.size()); - break; - } -} - } // namespace mlx::core diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index 1eccf1012..96e7f9ee9 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -3,6 +3,7 @@ #include #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" namespace mlx::core { @@ -10,23 +11,43 @@ namespace mlx::core { namespace { template -void arg_reduce(const array& in, array& out, const OpT& op, int axis) { +void arg_reduce( + const array& in, + array& out, + const OpT& op, + int axis, + Stream stream) { auto axis_size = in.shape()[axis]; auto axis_stride = in.strides()[axis]; Strides strides = in.strides(); Shape shape = in.shape(); strides.erase(strides.begin() + axis); shape.erase(shape.begin() + axis); - for (uint32_t i = 0; i < out.size(); ++i) { - auto loc = elem_to_loc(i, shape, strides); - auto in_ptr = in.data() + loc; - uint32_t ind_v = 0; - InT v = (*in_ptr); - for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) { - op(j, (*in_ptr), &ind_v, &v); + auto in_ptr = in.data(); + auto out_ptr = out.data(); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.dispatch([in_ptr, + out_ptr, + axis_size, + axis_stride, + op = std::move(op), + shape = std::move(shape), + strides = std::move(strides), + size = out.size()]() { + for (uint32_t i = 0; i < size; ++i) { + 
auto loc = elem_to_loc(i, shape, strides); + auto local_in_ptr = in_ptr + loc; + uint32_t ind_v = 0; + InT v = (*local_in_ptr); + for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) { + op(j, (*local_in_ptr), &ind_v, &v); + } + out_ptr[i] = ind_v; } - out.data()[i] = ind_v; - } + }); } template @@ -34,7 +55,8 @@ void arg_reduce_dispatch( const array& in, array& out, ArgReduce::ReduceType rtype, - int axis) { + int axis, + Stream stream) { switch (rtype) { case ArgReduce::ArgMin: { auto op = [](auto ind_x, auto x, auto ind_y, auto y) { @@ -43,7 +65,7 @@ void arg_reduce_dispatch( (*ind_y) = ind_x; } }; - arg_reduce(in, out, op, axis); + arg_reduce(in, out, op, axis, stream); break; } case ArgReduce::ArgMax: { @@ -53,7 +75,7 @@ void arg_reduce_dispatch( (*ind_y) = ind_x; } }; - arg_reduce(in, out, op, axis); + arg_reduce(in, out, op, axis, stream); break; } } @@ -68,46 +90,46 @@ void ArgReduce::eval_cpu(const std::vector& inputs, array& out) { switch (in.dtype()) { case bool_: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case uint8: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case uint16: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case uint32: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case uint64: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case int8: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case int16: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case int32: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case int64: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case float16: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case float32: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case bfloat16: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case float64: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; case complex64: - arg_reduce_dispatch(in, out, reduce_type_, axis_); + arg_reduce_dispatch(in, out, reduce_type_, axis_, stream()); break; } } diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index acb988c4b..55c4e69f4 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -16,49 +16,49 @@ namespace mlx::core { namespace { template -void comparison_op(const array& a, const array& b, array& out, Op op) { +void comparison_op(const array& a, const array& b, array& out) { switch (a.dtype()) { case bool_: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case uint8: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case uint16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case uint32: - binary_op(a, b, out, op); + 
binary_op(a, b, out); break; case uint64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int8: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int32: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case float16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case float32: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case float64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case bfloat16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case complex64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; } } @@ -151,47 +151,47 @@ void Equal::eval_cpu(const std::vector& inputs, array& out) { if (equal_nan_) { switch (a.dtype()) { case float16: - binary_op(a, b, out, detail::NaNEqual()); + binary_op(a, b, out); break; case float32: - binary_op(a, b, out, detail::NaNEqual()); + binary_op(a, b, out); break; case float64: - binary_op(a, b, out, detail::NaNEqual()); + binary_op(a, b, out); break; case bfloat16: - binary_op(a, b, out, detail::NaNEqual()); + binary_op(a, b, out); break; case complex64: - binary_op(a, b, out, detail::NaNEqual()); + binary_op(a, b, out); break; default: throw std::runtime_error( "[NanEqual::eval_cpu] Only for floating point types."); } } else { - comparison_op(a, b, out, detail::Equal()); + comparison_op(a, b, out); } } void Greater::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::Greater()); + comparison_op(inputs[0], inputs[1], out); } void GreaterEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual()); + comparison_op(inputs[0], inputs[1], out); } void Less::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::Less()); + comparison_op(inputs[0], inputs[1], out); } void LessEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::LessEqual()); + comparison_op(inputs[0], inputs[1], out); } void LogAddExp::eval_cpu(const std::vector& inputs, array& out) { @@ -200,16 +200,16 @@ void LogAddExp::eval_cpu(const std::vector& inputs, array& out) { auto& b = inputs[1]; switch (out.dtype()) { case float16: - binary_op(a, b, out, detail::LogAddExp()); + binary_op(a, b, out); break; case float32: - binary_op(a, b, out, detail::LogAddExp()); + binary_op(a, b, out); break; case float64: - binary_op(a, b, out, detail::LogAddExp()); + binary_op(a, b, out); break; case bfloat16: - binary_op(a, b, out, detail::LogAddExp()); + binary_op(a, b, out); break; default: throw std::runtime_error( @@ -254,7 +254,7 @@ void Multiply::eval_cpu(const std::vector& inputs, array& out) { void NotEqual::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op(inputs[0], inputs[1], out, detail::NotEqual()); + comparison_op(inputs[0], inputs[1], out); } void Power::eval_cpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/cpu/binary.h b/mlx/backend/cpu/binary.h index c85e6ca1a..623f1910a 100644 --- a/mlx/backend/cpu/binary.h +++ b/mlx/backend/cpu/binary.h @@ -7,6 +7,8 @@ #include "mlx/array.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/utils.h" +#include 
"mlx/backend/cpu/encoder.h" +#include "mlx/primitives.h" #include "mlx/backend/cpu/simd/simd.h" @@ -14,22 +16,18 @@ namespace mlx::core { template struct VectorScalar { - Op op; - - VectorScalar(Op op_) : op(op_) {} - template void operator()(const T* a, const T* b, U* dst, int size) { T scalar = *b; constexpr int N = simd::max_size; while (size >= N) { - simd::store(dst, op(simd::load(a), simd::Simd(scalar))); + simd::store(dst, Op{}(simd::load(a), simd::Simd(scalar))); dst += N; a += N; size -= N; } while (size-- > 0) { - *dst = op(*a, scalar); + *dst = Op{}(*a, scalar); dst++; a++; } @@ -38,22 +36,18 @@ struct VectorScalar { template struct ScalarVector { - Op op; - - ScalarVector(Op op_) : op(op_) {} - template void operator()(const T* a, const T* b, U* dst, int size) { T scalar = *a; constexpr int N = simd::max_size; while (size >= N) { - simd::store(dst, op(simd::Simd(scalar), simd::load(b))); + simd::store(dst, Op{}(simd::Simd(scalar), simd::load(b))); dst += N; b += N; size -= N; } while (size-- > 0) { - *dst = op(scalar, *b); + *dst = Op{}(scalar, *b); dst++; b++; } @@ -62,22 +56,18 @@ struct ScalarVector { template struct VectorVector { - Op op; - - VectorVector(Op op_) : op(op_) {} - template void operator()(const T* a, const T* b, U* dst, int size) { constexpr int N = simd::max_size; while (size >= N) { - simd::store(dst, op(simd::load(a), simd::load(b))); + simd::store(dst, Op{}(simd::load(a), simd::load(b))); dst += N; a += N; b += N; size -= N; } while (size-- > 0) { - *dst = op(*a, *b); + *dst = Op{}(*a, *b); dst++; a++; b++; @@ -90,7 +80,6 @@ void binary_op_dims( const T* a, const T* b, U* out, - Op op, const Shape& shape, const Strides& a_strides, const Strides& b_strides, @@ -104,12 +93,12 @@ void binary_op_dims( for (int i = 0; i < N; i++) { if constexpr (D > 1) { binary_op_dims( - a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1); + a, b, out, shape, a_strides, b_strides, out_strides, axis + 1); } else { if constexpr (Strided) { - op(a, b, out, stride_out); + Op{}(a, b, out, stride_out); } else { - *out = op(*a, *b); + *out = Op{}(*a, *b); } } out += stride_out; @@ -120,66 +109,38 @@ void binary_op_dims( template void binary_op_dispatch_dims( - const array& a, - const array& b, - array& out, - Op op, + const T* a, + const T* b, + U* out, int dim, + int size, const Shape& shape, const Strides& a_strides, const Strides& b_strides, const Strides& out_strides) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* out_ptr = out.data(); switch (dim) { case 1: binary_op_dims( - a_ptr, - b_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 0); + a, b, out, shape, a_strides, b_strides, out_strides, 0); return; case 2: binary_op_dims( - a_ptr, - b_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 0); + a, b, out, shape, a_strides, b_strides, out_strides, 0); return; case 3: binary_op_dims( - a_ptr, - b_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 0); + a, b, out, shape, a_strides, b_strides, out_strides, 0); return; } ContiguousIterator a_it(shape, a_strides, dim - 3); ContiguousIterator b_it(shape, b_strides, dim - 3); auto stride = out_strides[dim - 4]; - for (int64_t elem = 0; elem < a.size(); elem += stride) { + for (int64_t elem = 0; elem < size; elem += stride) { binary_op_dims( - a_ptr + a_it.loc, - b_ptr + b_it.loc, - out_ptr + elem, - op, + a + a_it.loc, + b + b_it.loc, + out + elem, shape, a_strides, b_strides, @@ -191,181 +152,216 @@ void binary_op_dispatch_dims( } 
template -void binary_op(const array& a, const array& b, array& out, Op op) { +void binary_op(const array& a, const array& b, array& out) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); // The full computation is scalar scalar so call the base op once - if (bopt == BinaryOpType::ScalarScalar) { - *(out.data()) = op(*a.data(), *b.data()); - return; - } + auto a_ptr = a.data(); + auto b_ptr = b.data(); - // The full computation is scalar vector so delegate to the op - if (bopt == BinaryOpType::ScalarVector) { - ScalarVector{op}(a.data(), b.data(), out.data(), b.data_size()); - return; - } - - // The full computation is vector scalar so delegate to the op - if (bopt == BinaryOpType::VectorScalar) { - VectorScalar{op}(a.data(), b.data(), out.data(), a.data_size()); - return; - } - - // The full computation is vector vector so delegate to the op - if (bopt == BinaryOpType::VectorVector) { - VectorVector{op}(a.data(), b.data(), out.data(), out.size()); - return; - } - - // General computation so let's try to optimize - auto [new_shape, new_strides] = collapse_contiguous_dims( - a.shape(), {a.strides(), b.strides(), out.strides()}); - const auto& a_strides = new_strides[0]; - const auto& b_strides = new_strides[1]; - const auto& strides = new_strides[2]; - - // Get the left-most dim such that the array is row contiguous after - auto leftmost_rc_dim = [&strides](const auto& arr_strides) { - int d = arr_strides.size() - 1; - for (; d >= 0 && arr_strides[d] == strides[d]; d--) { + auto out_ptr = out.data(); + auto& encoder = cpu::get_command_encoder(out.primitive().stream()); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([bopt, + a_ptr, + b_ptr, + out_ptr, + a_data_size = a.data_size(), + b_data_size = b.data_size(), + size = a.size(), + shape = a.shape(), + a_strides = a.strides(), + b_strides = b.strides(), + strides = out.strides()]() mutable { + if (bopt == BinaryOpType::ScalarScalar) { + *out_ptr = Op{}(*a_ptr, *b_ptr); + return; } - return d + 1; - }; - auto a_rc_dim = leftmost_rc_dim(a_strides); - auto b_rc_dim = leftmost_rc_dim(b_strides); - // Get the left-most dim such that the array is a broadcasted "scalar" after - auto leftmost_s_dim = [](const auto& arr_strides) { - int d = arr_strides.size() - 1; - for (; d >= 0 && arr_strides[d] == 0; d--) { + // The full computation is scalar vector so delegate to the op + if (bopt == BinaryOpType::ScalarVector) { + ScalarVector{}(a_ptr, b_ptr, out_ptr, b_data_size); + return; } - return d + 1; - }; - auto a_s_dim = leftmost_s_dim(a_strides); - auto b_s_dim = leftmost_s_dim(b_strides); - auto ndim = new_shape.size(); + // The full computation is vector scalar so delegate to the op + if (bopt == BinaryOpType::VectorScalar) { + VectorScalar{}(a_ptr, b_ptr, out_ptr, a_data_size); + return; + } - // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous - int dim = ndim; - if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { - bopt = BinaryOpType::VectorVector; - dim = d; - // Case 2: LxM and Fx1 where L and F are broadcastable and M is row + // The full computation is vector vector so delegate to the op + if (bopt == BinaryOpType::VectorVector) { + VectorVector{}(a_ptr, b_ptr, out_ptr, size); + return; + } + + // General computation so let's try to optimize + auto [new_shape, new_strides] = collapse_contiguous_dims( + shape, + {std::move(a_strides), std::move(b_strides), std::move(strides)}); + a_strides = new_strides[0]; + 
b_strides = new_strides[1]; + strides = new_strides[2]; + + // Get the left-most dim such that the array is row contiguous after + auto leftmost_rc_dim = [&strides](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == strides[d]; d--) { + } + return d + 1; + }; + auto a_rc_dim = leftmost_rc_dim(a_strides); + auto b_rc_dim = leftmost_rc_dim(b_strides); + + // Get the left-most dim such that the array is a broadcasted "scalar" after + auto leftmost_s_dim = [](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == 0; d--) { + } + return d + 1; + }; + auto a_s_dim = leftmost_s_dim(a_strides); + auto b_s_dim = leftmost_s_dim(b_strides); + + auto ndim = new_shape.size(); + + // Case 1: LxM and FxM where L and F are broadcastable and M is row // contiguous - } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { - bopt = BinaryOpType::VectorScalar; - dim = d; - // Case 3: Lx1 and FxM where L and F are broadcastable and M is row - // contiguous - } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { - bopt = BinaryOpType::ScalarVector; - dim = d; - } + int dim = ndim; + if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::VectorVector; + dim = d; + // Case 2: LxM and Fx1 where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { + bopt = BinaryOpType::VectorScalar; + dim = d; + // Case 3: Lx1 and FxM where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::ScalarVector; + dim = d; + } - // Can be sure dim > 0 since otherwise we would have used one of the fully - // contiguous methods above. Except for the case that the flags do not - // correspond to the underlying contiguity. - if (dim == 0 || strides[dim - 1] < 16) { - bopt = BinaryOpType::General; - dim = ndim; - } + // Can be sure dim > 0 since otherwise we would have used one of the fully + // contiguous methods above. Except for the case that the flags do not + // correspond to the underlying contiguity. 
+ if (dim == 0 || strides[dim - 1] < 16) { + bopt = BinaryOpType::General; + dim = ndim; + } - switch (bopt) { - case BinaryOpType::VectorVector: - binary_op_dispatch_dims( - a, - b, - out, - VectorVector{op}, - dim, - new_shape, - a_strides, - b_strides, - strides); - break; - case BinaryOpType::VectorScalar: - binary_op_dispatch_dims( - a, - b, - out, - VectorScalar{op}, - dim, - new_shape, - a_strides, - b_strides, - strides); - break; - case BinaryOpType::ScalarVector: - binary_op_dispatch_dims( - a, - b, - out, - ScalarVector{op}, - dim, - new_shape, - a_strides, - b_strides, - strides); - break; - default: - binary_op_dispatch_dims( - a, b, out, op, dim, new_shape, a_strides, b_strides, strides); - break; - } + switch (bopt) { + case BinaryOpType::VectorVector: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + size, + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::VectorScalar: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + size, + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::ScalarVector: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + size, + new_shape, + a_strides, + b_strides, + strides); + break; + default: + binary_op_dispatch_dims( + a_ptr, + b_ptr, + out_ptr, + dim, + size, + new_shape, + a_strides, + b_strides, + strides); + break; + } + }); +} + +template +void binary_op(const array& a, const array& b, array& out) { + binary_op(a, b, out); } template void binary_op(const array& a, const array& b, array& out, Op op) { - binary_op(a, b, out, op); + binary_op(a, b, out); } template void binary(const array& a, const array& b, array& out, Op op) { switch (out.dtype()) { case bool_: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case uint8: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case uint16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case uint32: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case uint64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int8: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int32: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case int64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case float16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case float32: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case float64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case bfloat16: - binary_op(a, b, out, op); + binary_op(a, b, out); break; case complex64: - binary_op(a, b, out, op); + binary_op(a, b, out); break; } } diff --git a/mlx/backend/cpu/binary_two.h b/mlx/backend/cpu/binary_two.h index 4b6bf25cd..a89f5aa11 100644 --- a/mlx/backend/cpu/binary_two.h +++ b/mlx/backend/cpu/binary_two.h @@ -4,6 +4,8 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/binary.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/primitives.h" namespace mlx::core { @@ -55,65 +57,81 @@ void binary_op_dispatch_dims( const array& b, array& out_a, array& out_b, + Stream stream, Op op) { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out_a); + encoder.set_output_array(out_b); + auto [shape, strides] = collapse_contiguous_dims( a.shape(), {a.strides(), b.strides(), out_a.strides()}); - const auto& a_strides = strides[0]; - const 
auto& b_strides = strides[1]; - const auto& out_strides = strides[2]; const T* a_ptr = a.data(); const T* b_ptr = b.data(); U* out_a_ptr = out_a.data(); U* out_b_ptr = out_b.data(); - int ndim = shape.size(); - switch (ndim) { - case 1: - binary_op_dims( - a_ptr, - b_ptr, - out_a_ptr, - out_b_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 0); - return; - case 2: - binary_op_dims( - a_ptr, - b_ptr, - out_a_ptr, - out_b_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 0); - return; - } + encoder.dispatch([a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + size = a.size(), + shape = std::move(shape), + strides = std::move(strides), + op = std::move(op)]() { + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& out_strides = strides[2]; + int ndim = shape.size(); + switch (ndim) { + case 1: + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + case 2: + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + } - ContiguousIterator a_it(shape, a_strides, ndim - 2); - ContiguousIterator b_it(shape, b_strides, ndim - 2); - auto stride = out_strides[ndim - 3]; - for (size_t elem = 0; elem < a.size(); elem += stride) { - binary_op_dims( - a_ptr + a_it.loc, - b_ptr + b_it.loc, - out_a_ptr + elem, - out_b_ptr + elem, - op, - shape, - a_strides, - b_strides, - out_strides, - ndim - 2); - a_it.step(); - b_it.step(); - } + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < size; elem += stride) { + binary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + out_a_ptr + elem, + out_b_ptr + elem, + op, + shape, + a_strides, + b_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); + } + }); } template @@ -128,40 +146,71 @@ void binary_op( set_binary_op_output_data(a, b, out_a, bopt); set_binary_op_output_data(a, b, out_b, bopt); + auto stream = out_a.primitive().stream(); // The full computation is scalar scalar so call the base op once if (bopt == BinaryOpType::General) { - binary_op_dispatch_dims(a, b, out_a, out_b, op); + binary_op_dispatch_dims(a, b, out_a, out_b, stream, op); return; } + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out_a); + encoder.set_output_array(out_b); + auto a_ptr = a.data(); auto b_ptr = b.data(); auto out_a_ptr = out_a.data(); auto out_b_ptr = out_b.data(); if (bopt == BinaryOpType::ScalarScalar) { - std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + encoder.dispatch( + [a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + }); } else if (bopt == BinaryOpType::ScalarVector) { - for (size_t i = 0; i < b.size(); ++i) { - std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); - out_a_ptr++; - out_b_ptr++; - b_ptr++; - } + encoder.dispatch([a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + size = b.size(), + op = std::move(op)]() mutable { + for (size_t i = 0; i < size; ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + b_ptr++; + } + }); } else if (bopt == BinaryOpType::VectorScalar) { - for (size_t i = 0; i < a.size(); ++i) { - std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); - out_a_ptr++; - out_b_ptr++; - a_ptr++; 
- } + encoder.dispatch([a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + size = a.size(), + op = std::move(op)]() mutable { + for (size_t i = 0; i < size; ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + } + }); } else { // VectorVector - for (size_t i = 0; i < a.size(); ++i) { - std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); - out_a_ptr++; - out_b_ptr++; - a_ptr++; - b_ptr++; - } + encoder.dispatch([a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + size = a.size(), + op = std::move(op)]() mutable { + for (size_t i = 0; i < size; ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + b_ptr++; + } + }); } } diff --git a/mlx/backend/cpu/cholesky.cpp b/mlx/backend/cpu/cholesky.cpp index 52a39c7c1..f6b28a85a 100644 --- a/mlx/backend/cpu/cholesky.cpp +++ b/mlx/backend/cpu/cholesky.cpp @@ -2,6 +2,7 @@ #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/linalg.h" #include "mlx/primitives.h" @@ -9,7 +10,7 @@ namespace mlx::core { template -void cholesky_impl(const array& a, array& factor, bool upper) { +void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) { // Lapack uses the column-major convention. We take advantage of the fact that // the matrix should be symmetric: // (A)ᵀ = A @@ -17,60 +18,63 @@ void cholesky_impl(const array& a, array& factor, bool upper) { // triangular matrix, so uplo is the opposite of what we would expect from // upper - char uplo = (upper) ? 'L' : 'U'; - // The decomposition is computed in place, so just copy the input to the // output. copy( a, factor, - a.flags().row_contiguous ? CopyType::Vector : CopyType::General); + a.flags().row_contiguous ? CopyType::Vector : CopyType::General, + stream); - const int N = a.shape(-1); - const size_t num_matrices = a.size() / (N * N); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(factor); + encoder.dispatch([matrix = factor.data(), + upper, + N = a.shape(-1), + size = a.size()]() mutable { + char uplo = (upper) ? 'L' : 'U'; + size_t num_matrices = size / (N * N); + for (int i = 0; i < num_matrices; i++) { + // Compute Cholesky factorization. + int info; + potrf( + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* info = */ &info); - T* matrix = factor.data(); - - for (int i = 0; i < num_matrices; i++) { - // Compute Cholesky factorization. - int info; - potrf( - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* info = */ &info); - - // TODO: We do nothing when the matrix is not positive semi-definite - // because throwing an error would result in a crash. If we figure out how - // to catch errors from the implementation we should throw. - if (info < 0) { - std::stringstream msg; - msg << "[cholesky] Cholesky decomposition failed with error code " - << info; - throw std::runtime_error(msg.str()); - } - - // Zero out the upper/lower triangle while advancing the pointer to the - // next matrix at the same time. - for (int row = 0; row < N; row++) { - if (upper) { - std::fill(matrix, matrix + row, 0); - } else { - std::fill(matrix + row + 1, matrix + N, 0); + // TODO: We do nothing when the matrix is not positive semi-definite + // because throwing an error would result in a crash. If we figure out how + // to catch errors from the implementation we should throw. 
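// (LAPACK's potrf reports info > 0 when a leading minor is not positive
// definite and info < 0 for an invalid argument, so only the latter is
// promoted to an exception below.)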
+ if (info < 0) { + std::stringstream msg; + msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code " + << info; + throw std::runtime_error(msg.str()); + } + + // Zero out the upper/lower triangle while advancing the pointer to the + // next matrix at the same time. + for (int row = 0; row < N; row++) { + if (upper) { + std::fill(matrix, matrix + row, 0); + } else { + std::fill(matrix + row + 1, matrix + N, 0); + } + matrix += N; } - matrix += N; } - } + }); } void Cholesky::eval_cpu(const std::vector& inputs, array& output) { switch (inputs[0].dtype()) { case float32: - cholesky_impl(inputs[0], output, upper_); + cholesky_impl(inputs[0], output, upper_, stream()); break; case float64: - cholesky_impl(inputs[0], output, upper_); + cholesky_impl(inputs[0], output, upper_, stream()); break; default: throw std::runtime_error( diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 905db82c8..9da9c14e8 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -11,6 +11,7 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/cpu/compiled_preamble.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/jit_compiler.h" #include "mlx/device.h" #include "mlx/graph_utils.h" @@ -288,6 +289,7 @@ void Compiled::eval_cpu( // Figure out which kernel we are using auto& shape = outputs[0].shape(); auto contiguous = compiled_check_contiguity(inputs, shape); + auto& encoder = cpu::get_command_encoder(stream()); // Handle all broadcasting and collect function input arguments std::vector args; @@ -298,6 +300,7 @@ void Compiled::eval_cpu( continue; } auto& x = inputs[i]; + encoder.set_input_array(x); args.push_back((void*)x.data()); if (contiguous || is_scalar(x)) { @@ -356,18 +359,25 @@ void Compiled::eval_cpu( }); compiled_allocate_outputs( - inputs, outputs, inputs_, constant_ids_, contiguous, false); + inputs, outputs, inputs_, constant_ids_, contiguous); for (auto& x : outputs) { args.push_back(x.data()); + encoder.set_output_array(x); } + Shape out_shape; if (!contiguous) { - args.push_back((void*)outputs[0].shape().data()); + out_shape = outputs[0].shape(); + args.push_back((void*)out_shape.data()); } else { args.push_back((void*)outputs[0].data_size()); } auto fun = (void (*)(void**))fn_ptr; - fun(args.data()); + encoder.dispatch( + [fun, + args = std::move(args), + strides = std::move(strides), + out_shape = std::move(out_shape)]() mutable { fun(args.data()); }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index 418a8cf25..8259e0597 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -4,6 +4,7 @@ #include #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -25,68 +26,82 @@ void slow_conv_1D( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { - const T* start_wt_ptr = wt.data(); + bool flip, + Stream stream) { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(in); + encoder.set_input_array(wt); + encoder.set_output_array(out); - const T* in_ptr = in.data(); - T* out_ptr = out.data(); + encoder.dispatch([start_wt_ptr = wt.data(), + in_ptr = in.data(), + out_ptr = out.data(), - const int N = in.shape(0); // Batch size, should be the same as out.shape(0) - const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim - const int C = in.shape(2); 
// Input channels - const int oH = out.shape(1); // Output spatial dim - const int O = wt.shape(0); // Out channels - const int wH = wt.shape(1); // Weight spatial dim + N = in.shape( + 0), // Batch size, should be the same as out.shape(0) + iH = 1 + + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + oH = out.shape(1), // Output spatial dim + wH = wt.shape(1), // Weight spatial dim + groups = in.shape(2) / wt.shape(2), + O = wt.shape(0), // Out channels + C_per_group = wt.shape(2), - const int groups = C / wt.shape(2); - const int C_per_group = wt.shape(2); - const int O_per_group = O / groups; + in_stride_N = in.strides()[0], + in_stride_H = in.strides()[1], + in_stride_C = in.strides()[2], - const size_t in_stride_N = in.strides()[0]; - const size_t in_stride_H = in.strides()[1]; - const size_t in_stride_C = in.strides()[2]; + wt_stride_O = wt.strides()[0], + wt_stride_H = wt.strides()[1], + wt_stride_C = wt.strides()[2], - const size_t wt_stride_O = wt.strides()[0]; - const size_t wt_stride_H = wt.strides()[1]; - const size_t wt_stride_C = wt.strides()[2]; + out_stride_N = out.strides()[0], + out_stride_H = out.strides()[1], + out_stride_O = out.strides()[2], - const size_t out_stride_N = out.strides()[0]; - const size_t out_stride_H = out.strides()[1]; - const size_t out_stride_O = out.strides()[2]; + flip, + padding = padding[0], + wt_stride = wt_strides[0], + wt_dilation = wt_dilation[0], + in_dilation = in_dilation[0]]() mutable { + auto O_per_group = O / groups; - for (int n = 0; n < N; ++n) { - for (int oh = 0; oh < oH; ++oh) { - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O; - float r = 0.; + for (int n = 0; n < N; ++n) { + for (int oh = 0; oh < oH; ++oh) { + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O; + float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; + for (int wh = 0; wh < wH; ++wh) { + const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; - int wh_flip = flip ? (wH - wh - 1) : wh; - int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0]; + int wh_flip = flip ? 
(wH - wh - 1) : wh; + int ih = oh * wt_stride - padding + wh_flip * wt_dilation; - auto ih_div = std::div(ih, in_dilation[0]); + auto ih_div = std::div(ih, in_dilation); - if (ih >= 0 && ih < iH && ih_div.rem == 0) { - for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { - r += static_cast( - in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) * - static_cast(wt_ptr[(c % C_per_group) * wt_stride_C]); - } // c + if (ih >= 0 && ih < iH && ih_div.rem == 0) { + for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { + r += + static_cast( + in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) * + static_cast( + wt_ptr[(c % C_per_group) * wt_stride_C]); + } // c - } // ih check - } // wh + } // ih check + } // wh - out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast(r); - } // o - } // g - } // oh + out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast(r); + } // o + } // g + } // oh - in_ptr += in_stride_N; - out_ptr += out_stride_N; - } // n + in_ptr += in_stride_N; + out_ptr += out_stride_N; + } // n + }); } template @@ -98,219 +113,237 @@ void slow_conv_2D( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { - const T* st_wt_ptr = wt.data(); - const T* st_in_ptr = in.data(); - T* st_out_ptr = out.data(); + bool flip, + Stream stream) { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(in); + encoder.set_input_array(wt); + encoder.set_output_array(out); - const int N = in.shape(0); // Batch size, should be the same as out.shape(0) - const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim - const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim - const int C = in.shape(3); // In channels - const int oH = out.shape(1); // Output spatial dim - const int oW = out.shape(2); // Output spatial dim - const int O = wt.shape(0); // Out channels - const int wH = wt.shape(1); // Weight spatial dim - const int wW = wt.shape(2); // Weight spatial dim + encoder.dispatch([st_wt_ptr = wt.data(), + st_in_ptr = in.data(), + st_out_ptr = out.data(), - const int groups = C / wt.shape(3); - const int C_per_group = wt.shape(3); - const int O_per_group = O / groups; + N = in.shape( + 0), // Batch size, should be the same as out.shape(0) + iH = 1 + + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + iW = 1 + + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim + C = in.shape(3), // In channels + oH = out.shape(1), // Output spatial dim + oW = out.shape(2), // Output spatial dim + O = wt.shape(0), // Out channels + wH = wt.shape(1), // Weight spatial dim + wW = wt.shape(2), // Weight spatial dim - const size_t in_stride_N = in.strides()[0]; - const size_t in_stride_H = in.strides()[1]; - const size_t in_stride_W = in.strides()[2]; - const size_t in_stride_C = in.strides()[3]; + groups = in.shape(3) / wt.shape(3), + C_per_group = wt.shape(3), - const size_t wt_stride_O = wt.strides()[0]; - const size_t wt_stride_H = wt.strides()[1]; - const size_t wt_stride_W = wt.strides()[2]; - const size_t wt_stride_C = wt.strides()[3]; + in_stride_N = in.strides()[0], + in_stride_H = in.strides()[1], + in_stride_W = in.strides()[2], + in_stride_C = in.strides()[3], - const size_t out_stride_N = out.strides()[0]; - const size_t out_stride_H = out.strides()[1]; - const size_t out_stride_W = out.strides()[2]; - const size_t out_stride_O = out.strides()[3]; + wt_stride_O = wt.strides()[0], + wt_stride_H = wt.strides()[1], + wt_stride_W = wt.strides()[2], + 
wt_stride_C = wt.strides()[3], - bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; + out_stride_N = out.strides()[0], + out_stride_H = out.strides()[1], + out_stride_W = out.strides()[2], + out_stride_O = out.strides()[3], - auto pt_conv_no_checks = - [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + padding, + wt_strides, + wt_dilation, + in_dilation, + flip]() mutable { + bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + const int O_per_group = O / groups; + auto pt_conv_no_checks = [&](const T* in_ptr, + const T* wt_ptr, + T* out_ptr, + int oh, + int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding[0]; + int iw_base = ow * wt_strides[1] - padding[1]; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = - in_ptr + ih * in_stride_H + iw * in_stride_W; + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c - } // ww - } // wh + const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + } // ww + } // wh - int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; - int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); - int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); + int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; - int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; - int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); + int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0); - int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; - int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_wgt_jump_h = + std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_w = + std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; - std::vector base_h(f_out_jump_h); - std::vector base_w(f_out_jump_w); + int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; - for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[0] - padding[0] + init_h; + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); - int wh_base = 0; - while (wh_base < wH && ih_loop % in_dilation[0] != 0) { - wh_base++; - ih_loop += jump_h; + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[0] - padding[0] + init_h; + + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[0] != 0) { + wh_base++; + ih_loop += jump_h; + } + + base_h[i] = wh_base; } - base_h[i] = wh_base; - } + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[1] - padding[1] + init_w; - for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[1] - padding[1] + init_w; + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[1] != 0) { + ww_base++; + iw_loop += jump_w; + } - int ww_base = 0; - while (ww_base < wW && iw_loop % in_dilation[1] != 0) { - ww_base++; - iw_loop += jump_w; + base_w[j] = ww_base; } - base_w[j] = ww_base; - } + auto pt_conv_all_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; - auto pt_conv_all_checks = - [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding[0]; + int iw_base = ow * wt_strides[1] - padding[1]; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; - int wh_base = base_h[oh % f_out_jump_h]; - int ww_base = base_w[ow % f_out_jump_w]; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; - int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; - int iw_dil = !is_idil_one ? 
(iw / in_dilation[1]) : iw; + const T* in_ptr_pt = + in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; - const T* in_ptr_pt = - in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c - for (int c = g * C_per_group; c < (g + 1) * C_per_group; - ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c + } // ih, iw check + } // ww + } // wh - } // ih, iw check - } // ww - } // wh + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + int oH_border_0 = 0; + int oH_border_1 = + is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; + int oH_border_2 = std::max( + oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); + int oH_border_3 = oH; - int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; - int oH_border_2 = std::max( - oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); - int oH_border_3 = oH; + int oW_border_0 = 0; + int oW_border_1 = + is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; + int oW_border_2 = std::max( + oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); + int oW_border_3 = oW; - int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; - int oW_border_2 = std::max( - oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); - int oW_border_3 = oW; + for (int n = 0; n < N; ++n) { + // Case 1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - for (int n = 0; n < N; ++n) { - // Case 1: oh might put us out of bounds - for (int oh = oH_border_0; oh < oH_border_1; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + // Case 2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case a: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case 2: oh in bounds - for (int oh = oH_border_1; oh < oH_border_2; ++oh) { - // Case a: ow might put us out of bounds - for (int ow = oW_border_0; ow < oW_border_1; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case b: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case b: ow in bounds - for (int ow = oW_border_1; ow < oW_border_2; ++ow) { - pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case c: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case c: ow might put us out of bounds - for (int ow = oW_border_2; ow < oW_border_3; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + } // oh - } // oh + // Case 3: oh might put us out of bounds + for 
(int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - // Case 3: oh might put us out of bounds - for (int oh = oH_border_2; oh < oH_border_3; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; - st_in_ptr += in_stride_N; - st_out_ptr += out_stride_N; - - } // n + } // n + }); } template @@ -322,187 +355,87 @@ void slow_conv_3D( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { - const T* st_wt_ptr = wt.data(); - const T* st_in_ptr = in.data(); - T* st_out_ptr = out.data(); + bool flip, + Stream stream) { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(in); + encoder.set_input_array(wt); + encoder.set_output_array(out); - const int N = in.shape(0); // Batch size, should be the same as out.shape(0) - const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim - const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim - const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim - const int oD = out.shape(1); // Output spatial dim - const int oH = out.shape(2); // Output spatial dim - const int oW = out.shape(3); // Output spatial dim - const int O = wt.shape(0); // Out channels - const int C = wt.shape(4); // In channels - const int wD = wt.shape(1); // Weight spatial dim - const int wH = wt.shape(2); // Weight spatial dim - const int wW = wt.shape(3); // Weight spatial dim + encoder.dispatch([st_wt_ptr = wt.data(), + st_in_ptr = in.data(), + st_out_ptr = out.data(), - const size_t in_stride_N = in.strides()[0]; - const size_t in_stride_D = in.strides()[1]; - const size_t in_stride_H = in.strides()[2]; - const size_t in_stride_W = in.strides()[3]; - const size_t in_stride_C = in.strides()[4]; + N = in.shape( + 0), // Batch size, should be the same as out.shape(0) + iD = 1 + + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + iH = 1 + + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim + iW = 1 + + in_dilation[2] * (in.shape(3) - 1), // Input spatial dim + oD = out.shape(1), // Output spatial dim + oH = out.shape(2), // Output spatial dim + oW = out.shape(3), // Output spatial dim + O = wt.shape(0), // Out channels + C = wt.shape(4), // In channels + wD = wt.shape(1), // Weight spatial dim + wH = wt.shape(2), // Weight spatial dim + wW = wt.shape(3), // Weight spatial dim - const size_t wt_stride_O = wt.strides()[0]; - const size_t wt_stride_D = wt.strides()[1]; - const size_t wt_stride_H = wt.strides()[2]; - const size_t wt_stride_W = wt.strides()[3]; - const size_t wt_stride_C = wt.strides()[4]; + in_stride_N = in.strides()[0], + in_stride_D = in.strides()[1], + in_stride_H = in.strides()[2], + in_stride_W = in.strides()[3], + in_stride_C = in.strides()[4], - const size_t out_stride_N = out.strides()[0]; - const size_t out_stride_D = out.strides()[1]; - const size_t out_stride_H = out.strides()[2]; - const size_t out_stride_W = out.strides()[3]; - const size_t out_stride_O = out.strides()[4]; + wt_stride_O = wt.strides()[0], + wt_stride_D = wt.strides()[1], + wt_stride_H = wt.strides()[2], + wt_stride_W = wt.strides()[3], + wt_stride_C = wt.strides()[4], - bool is_idil_one = - in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1; + out_stride_N = out.strides()[0], + 
out_stride_D = out.strides()[1], + out_stride_H = out.strides()[2], + out_stride_W = out.strides()[3], + out_stride_O = out.strides()[4], + padding, + wt_strides, + wt_dilation, + in_dilation, + flip]() mutable { + bool is_idil_one = + in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1; - auto pt_conv_no_checks = [&](const T* in_ptr, - const T* wt_ptr, - T* out_ptr, - int od, - int oh, - int ow) { - out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + auto pt_conv_no_checks = [&](const T* in_ptr, + const T* wt_ptr, + T* out_ptr, + int od, + int oh, + int ow) { + out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; + int id_base = od * wt_strides[0] - padding[0]; + int ih_base = oh * wt_strides[1] - padding[1]; + int iw_base = ow * wt_strides[2] - padding[2]; - for (int o = 0; o < O; ++o) { - float r = 0.; + for (int o = 0; o < O; ++o) { + float r = 0.; - for (int wd = 0; wd < wD; ++wd) { - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wd_flip = flip ? wD - wd - 1 : wd; - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int id = id_base + wd_flip * wt_dilation[0]; - int ih = ih_base + wh_flip * wt_dilation[1]; - int iw = iw_base + ww_flip * wt_dilation[2]; + for (int wd = 0; wd < wD; ++wd) { + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wd_flip = flip ? wD - wd - 1 : wd; + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int id = id_base + wd_flip * wt_dilation[0]; + int ih = ih_base + wh_flip * wt_dilation[1]; + int iw = iw_base + ww_flip * wt_dilation[2]; - const T* wt_ptr_pt = - wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = - in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W; - - for (int c = 0; c < C; ++c) { - r += static_cast(in_ptr_pt[0]) * - static_cast(wt_ptr_pt[0]); - in_ptr_pt += in_stride_C; - wt_ptr_pt += wt_stride_C; - } // c - - } // ww - } // wh - } // wd - - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - }; - - int jump_d = flip ? -wt_dilation[0] : wt_dilation[0]; - int jump_h = flip ? -wt_dilation[1] : wt_dilation[1]; - int jump_w = flip ? -wt_dilation[2] : wt_dilation[2]; - - int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0); - int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0); - int init_w = (flip ? 
(wW - 1) * wt_dilation[2] : 0); - - int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; - int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; - int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2]; - - int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; - int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; - int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2]; - - std::vector base_d(f_out_jump_d); - std::vector base_h(f_out_jump_h); - std::vector base_w(f_out_jump_w); - - for (int i = 0; i < f_out_jump_d; ++i) { - int id_loop = i * wt_strides[0] - padding[0] + init_d; - - int wd_base = 0; - while (wd_base < wD && id_loop % in_dilation[0] != 0) { - wd_base++; - id_loop += jump_d; - } - - base_d[i] = wd_base; - } - - for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[1] - padding[1] + init_h; - - int wh_base = 0; - while (wh_base < wH && ih_loop % in_dilation[1] != 0) { - wh_base++; - ih_loop += jump_h; - } - - base_h[i] = wh_base; - } - - for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[2] - padding[2] + init_w; - - int ww_base = 0; - while (ww_base < wW && iw_loop % in_dilation[2] != 0) { - ww_base++; - iw_loop += jump_w; - } - - base_w[j] = ww_base; - } - - auto pt_conv_all_checks = [&](const T* in_ptr, - const T* wt_ptr, - T* out_ptr, - int od, - int oh, - int ow) { - out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; - - int wd_base = base_d[od % f_out_jump_d]; - int wh_base = base_h[oh % f_out_jump_h]; - int ww_base = base_w[ow % f_out_jump_w]; - - for (int o = 0; o < O; ++o) { - float r = 0.; - - for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) { - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wd_flip = flip ? wD - wd - 1 : wd; - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int id = id_base + wd_flip * wt_dilation[0]; - int ih = ih_base + wh_flip * wt_dilation[1]; - int iw = iw_base + ww_flip * wt_dilation[2]; - - if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 && - iw < iW) { const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; - - int id_dil = !is_idil_one ? (id / in_dilation[0]) : id; - int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw; - - const T* in_ptr_pt = in_ptr + id_dil * in_stride_D + - ih_dil * in_stride_H + iw_dil * in_stride_W; + const T* in_ptr_pt = in_ptr + id * in_stride_D + + ih * in_stride_H + iw * in_stride_W; for (int c = 0; c < C; ++c) { r += static_cast(in_ptr_pt[0]) * @@ -511,96 +444,214 @@ void slow_conv_3D( wt_ptr_pt += wt_stride_C; } // c - } // iD, ih, iw check - } // ww - } // wh - } // wd + } // ww + } // wh + } // wd - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + }; - int oD_border_0 = 0; - int oD_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; - int oD_border_2 = std::max( - oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); - int oD_border_3 = oD; + int jump_d = flip ? 
-wt_dilation[0] : wt_dilation[0]; + int jump_h = flip ? -wt_dilation[1] : wt_dilation[1]; + int jump_w = flip ? -wt_dilation[2] : wt_dilation[2]; - int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; - int oH_border_2 = std::max( - oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); - int oH_border_3 = oH; + int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0); + int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0); + int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0); - int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; - int oW_border_2 = std::max( - oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); - int oW_border_3 = oW; + int f_wgt_jump_d = + std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_h = + std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int f_wgt_jump_w = + std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2]; - for (int n = 0; n < N; ++n) { - // Case 1: od might put us out of bounds - for (int od = oD_border_0; od < oD_border_1; ++od) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); - } // ow - } // oh - } // od + int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2]; - // Case 2: od in bounds - for (int od = oD_border_1; od < oD_border_2; ++od) { - // Case 2.1: oh might put us out of bounds - for (int oh = oH_border_0; oh < oH_border_1; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); - } // ow - } // oh + std::vector base_d(f_out_jump_d); + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); - // Case 2.2: oh in bounds - for (int oh = oH_border_1; oh < oH_border_2; ++oh) { - // Case 2.2.1: ow might put us out of bounds - for (int ow = oW_border_0; ow < oW_border_1; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); - } // ow + for (int i = 0; i < f_out_jump_d; ++i) { + int id_loop = i * wt_strides[0] - padding[0] + init_d; - // Case 2.2.2: ow in bounds - for (int ow = oW_border_1; ow < oW_border_2; ++ow) { - pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); - } // ow + int wd_base = 0; + while (wd_base < wD && id_loop % in_dilation[0] != 0) { + wd_base++; + id_loop += jump_d; + } - // Case 2.2.3: ow might put us out of bounds - for (int ow = oW_border_2; ow < oW_border_3; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); - } // ow - } // oh + base_d[i] = wd_base; + } - // Case 2.3: oh might put us out of bounds - for (int oh = oH_border_2; oh < oH_border_3; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); - } // ow - } // oh - } // od + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[1] - padding[1] + init_h; - // Case 3: od might put us out of bounds - for (int od = oD_border_2; od < oD_border_3; ++od) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); - } // ow - } // oh - } // od + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[1] != 0) { + wh_base++; + ih_loop += 
jump_h; + } - st_in_ptr += in_stride_N; - st_out_ptr += out_stride_N; + base_h[i] = wh_base; + } - } // n + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[2] - padding[2] + init_w; + + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[2] != 0) { + ww_base++; + iw_loop += jump_w; + } + + base_w[j] = ww_base; + } + + auto pt_conv_all_checks = [&](const T* in_ptr, + const T* wt_ptr, + T* out_ptr, + int od, + int oh, + int ow) { + out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; + + int id_base = od * wt_strides[0] - padding[0]; + int ih_base = oh * wt_strides[1] - padding[1]; + int iw_base = ow * wt_strides[2] - padding[2]; + + int wd_base = base_d[od % f_out_jump_d]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; + + for (int o = 0; o < O; ++o) { + float r = 0.; + + for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) { + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wd_flip = flip ? wD - wd - 1 : wd; + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int id = id_base + wd_flip * wt_dilation[0]; + int ih = ih_base + wh_flip * wt_dilation[1]; + int iw = iw_base + ww_flip * wt_dilation[2]; + + if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 && + iw < iW) { + const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + + wh * wt_stride_H + ww * wt_stride_W; + + int id_dil = !is_idil_one ? (id / in_dilation[0]) : id; + int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw; + + const T* in_ptr_pt = in_ptr + id_dil * in_stride_D + + ih_dil * in_stride_H + iw_dil * in_stride_W; + + for (int c = 0; c < C; ++c) { + r += static_cast(in_ptr_pt[0]) * + static_cast(wt_ptr_pt[0]); + in_ptr_pt += in_stride_C; + wt_ptr_pt += wt_stride_C; + } // c + + } // iD, ih, iw check + } // ww + } // wh + } // wd + + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + }; + + int oD_border_0 = 0; + int oD_border_1 = + is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; + int oD_border_2 = std::max( + oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); + int oD_border_3 = oD; + + int oH_border_0 = 0; + int oH_border_1 = + is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; + int oH_border_2 = std::max( + oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); + int oH_border_3 = oH; + + int oW_border_0 = 0; + int oW_border_1 = + is_idil_one ? 
((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; + int oW_border_2 = std::max( + oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); + int oW_border_3 = oW; + + for (int n = 0; n < N; ++n) { + // Case 1: od might put us out of bounds + for (int od = oD_border_0; od < oD_border_1; ++od) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + } // od + + // Case 2: od in bounds + for (int od = oD_border_1; od < oD_border_2; ++od) { + // Case 2.1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + + // Case 2.2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case 2.2.1: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + + // Case 2.2.2: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + + // Case 2.2.3: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + + // Case 2.3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + } // od + + // Case 3: od might put us out of bounds + for (int od = oD_border_2; od < oD_border_3; ++od) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + } // od + + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; + + } // n + }); } void dispatch_slow_conv_1D( @@ -611,16 +662,41 @@ void dispatch_slow_conv_1D( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { + bool flip, + Stream stream) { if (in.dtype() == float32) { return slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else if (in.dtype() == float16) { return slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else if (in.dtype() == bfloat16) { return slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); @@ -635,16 +711,41 @@ void dispatch_slow_conv_2D( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { + bool flip, + Stream stream) { if (in.dtype() == float32) { return slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else if (in.dtype() == float16) { return slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + 
in_dilation, + flip, + stream); } else if (in.dtype() == bfloat16) { return slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); @@ -659,16 +760,41 @@ void dispatch_slow_conv_3D( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { + bool flip, + Stream stream) { if (in.dtype() == float32) { return slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else if (in.dtype() == float16) { return slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else if (in.dtype() == bfloat16) { return slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, + wt, + out, + padding, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); @@ -680,17 +806,11 @@ void dispatch_slow_conv_3D( /////////////////////////////////////////////////////////////////////////////// template -void flip_spatial_dims_inplace(array& wt) { - T* x = wt.data(); - size_t out_channels = wt.shape(0); - size_t in_channels = wt.shape(-1); - - // Calculate the total size of the spatial dimensions - int spatial_size = 1; - for (int d = 1; d < wt.ndim() - 1; ++d) { - spatial_size *= wt.shape(d); - } - +void flip_spatial_dims_inplace( + T* x, + size_t in_channels, + size_t out_channels, + size_t spatial_size) { for (size_t i = 0; i < out_channels; i++) { T* top = x + i * spatial_size * in_channels; T* bottom = @@ -711,7 +831,8 @@ void explicit_gemm_conv_1D_cpu( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + Stream stream) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = in.shape(1); // Input spatial dim const int C = in.shape(2); // Input channels @@ -724,13 +845,16 @@ void explicit_gemm_conv_1D_cpu( const int O_per_group = O / groups; auto conv_dtype = float32; + auto& encoder = cpu::get_command_encoder(stream); // Pad input Shape padded_shape = {N, iH + 2 * padding[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros - copy(array(0, conv_dtype), in_padded, CopyType::Scalar); + std::vector temps; + temps.push_back(array(0, conv_dtype)); + copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = padding[0] * in_padded.strides()[1]; @@ -741,9 +865,9 @@ void explicit_gemm_conv_1D_cpu( in_padded.flags(), in_padded_slice.size(), data_offset); - // Copy input values into the slice - copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); + copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); + temps.push_back(in_padded_slice); // Make strided view Shape strided_shape = {N, oH, wH, C}; @@ -767,7 +891,8 @@ void explicit_gemm_conv_1D_cpu( // Materialize strided view Shape strided_reshape = {N * oH, wH * C}; array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy(in_strided_view, in_strided, CopyType::General); + copy(in_strided_view, in_strided, 
CopyType::General, stream); + temps.push_back(in_strided); // Check wt dtype and prepare auto gemm_wt = wt; @@ -784,43 +909,62 @@ void explicit_gemm_conv_1D_cpu( wt.size(), 0); gemm_wt = array(wt_transpose.shape(), float32, nullptr, {}); - copy(wt_transpose, gemm_wt, CopyType::General); + copy(wt_transpose, gemm_wt, CopyType::General, stream); + temps.push_back(gemm_wt); } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); - copy(wt, gemm_wt, ctype); + copy(wt, gemm_wt, ctype, stream); + temps.push_back(gemm_wt); } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + temps.push_back(gemm_out); } - for (int g = 0; g < groups; ++g) { - // Perform gemm - cblas_sgemm( - CblasRowMajor, - CblasNoTrans, // no trans A - CblasTrans, // transB - strided_reshape[0], // M - O_per_group, // N - C_per_group * wH, // K - 1.0f, // alpha - in_strided.data() + g * C_per_group * wH, // A - wH * C, // lda - gemm_wt.data() + g * O_per_group * C_per_group * wH, // B - wH * C_per_group, // ldb - 0.0f, // beta - gemm_out.data() + g * O_per_group, // C - O // ldc - ); + encoder.set_input_array(in_strided); + encoder.set_input_array(gemm_wt); + encoder.set_output_array(gemm_out); - // Copy results if needed - if (out.dtype() != float32) { - copy(gemm_out, out, CopyType::Vector); + encoder.dispatch([in_strided_ptr = in_strided.data(), + gemm_wt_ptr = gemm_wt.data(), + gemm_out_ptr = gemm_out.data(), + groups, + strided_reshape = strided_reshape[0], + O, + C, + wH, + O_per_group, + C_per_group]() { + for (int g = 0; g < groups; ++g) { + // Perform gemm + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, // no trans A + CblasTrans, // transB + strided_reshape, // M + O_per_group, // N + C_per_group * wH, // K + 1.0f, // alpha + in_strided_ptr + g * C_per_group * wH, // A + wH * C, // lda + gemm_wt_ptr + g * O_per_group * C_per_group * wH, // B + wH * C_per_group, // ldb + 0.0f, // beta + gemm_out_ptr + g * O_per_group, // C + O // ldc + ); } + }); + + // Copy results if needed + if (out.dtype() != float32) { + copy_inplace(gemm_out, out, CopyType::Vector, stream); } + encoder.add_temporaries(std::move(temps)); } void explicit_gemm_conv_2D_cpu( @@ -829,7 +973,8 @@ void explicit_gemm_conv_2D_cpu( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + Stream stream) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = in.shape(1); // Input spatial dim const int iW = in.shape(2); // Input spatial dim @@ -841,13 +986,16 @@ void explicit_gemm_conv_2D_cpu( const int wW = wt.shape(2); // Weight spatial dim auto conv_dtype = out.dtype(); + auto& encoder = cpu::get_command_encoder(stream); // Pad input Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros - copy(array(0, conv_dtype), in_padded, CopyType::Scalar); + std::vector temps; + temps.push_back(array(0, conv_dtype)); + copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = @@ -859,9 +1007,10 @@ void explicit_gemm_conv_2D_cpu( in_padded.flags(), in_padded_slice.size(), data_offset); + temps.push_back(in_padded_slice); // Copy input values into the slice - 
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); + copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); // Make strided view Shape strided_shape = {N, oH, oW, wH, wW, C}; @@ -882,7 +1031,8 @@ void explicit_gemm_conv_2D_cpu( // Materialize strided view Shape strided_reshape = {N * oH * oW, wH * wW * C}; array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy(in_strided_view, in_strided, CopyType::General); + copy(in_strided_view, in_strided, CopyType::General, stream); + temps.push_back(in_strided); // Check wt dtype and prepare auto gemm_wt = wt; @@ -892,36 +1042,49 @@ void explicit_gemm_conv_2D_cpu( auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); - copy(wt, gemm_wt, ctype); + copy(wt, gemm_wt, ctype, stream); + temps.push_back(gemm_wt); } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + temps.push_back(gemm_out); } - // Perform gemm - cblas_sgemm( - CblasRowMajor, - CblasNoTrans, // no trans A - CblasTrans, // transB - strided_reshape[0], // M - O, // N - strided_reshape[1], // K - 1.0f, // alpha - in_strided.data(), - strided_reshape[1], // lda - gemm_wt.data(), - strided_reshape[1], // ldb - 0.0f, // beta - gemm_out.data(), - O // ldc - ); + encoder.set_input_array(in_strided); + encoder.set_input_array(gemm_wt); + encoder.set_output_array(gemm_out); + + encoder.dispatch([in_strided_ptr = in_strided.data(), + gemm_wt_ptr = gemm_wt.data(), + gemm_out_ptr = gemm_out.data(), + strided_reshape = std::move(strided_reshape), + O]() { + // Perform gemm + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, // no trans A + CblasTrans, // transB + strided_reshape[0], // M + O, // N + strided_reshape[1], // K + 1.0f, // alpha + in_strided_ptr, + strided_reshape[1], // lda + gemm_wt_ptr, + strided_reshape[1], // ldb + 0.0f, // beta + gemm_out_ptr, + O // ldc + ); + }); // Copy results if needed if (out.dtype() != float32) { - copy(gemm_out, out, CopyType::Vector); + copy_inplace(gemm_out, out, CopyType::Vector, stream); } + encoder.add_temporaries(std::move(temps)); } void explicit_gemm_conv_ND_cpu( @@ -931,7 +1094,8 @@ void explicit_gemm_conv_ND_cpu( const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, - const bool flip) { + const bool flip, + Stream stream) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const auto iDim = Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim @@ -944,6 +1108,8 @@ void explicit_gemm_conv_ND_cpu( auto conv_dtype = float32; + auto& encoder = cpu::get_command_encoder(stream); + // Pad input Shape padded_shape(in.shape().size()); padded_shape.front() = N; @@ -954,7 +1120,8 @@ void explicit_gemm_conv_ND_cpu( array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros - copy(array(0, conv_dtype), in_padded, CopyType::Scalar); + std::vector temps = {array(0, conv_dtype)}; + copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = 0; @@ -970,7 +1137,8 @@ void explicit_gemm_conv_ND_cpu( data_offset); // Copy input values into the slice - copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); + copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); + temps.push_back(in_padded_slice); // Make strided view Shape strided_shape(oDim.size() + wDim.size() + 2); @@ -1008,7 +1176,8 @@ 
void explicit_gemm_conv_ND_cpu( } array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy(in_strided_view, in_strided, CopyType::General); + copy(in_strided_view, in_strided, CopyType::General, stream); + temps.push_back(in_strided); // Check wt dtype and prepare auto gemm_wt = wt; @@ -1018,44 +1187,70 @@ void explicit_gemm_conv_ND_cpu( auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); - copy(wt, gemm_wt, ctype); + copy(wt, gemm_wt, ctype, stream); + temps.push_back(gemm_wt); } if (flip) { auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {}); - copy(gemm_wt, gemm_wt_, CopyType::Vector); + copy(gemm_wt, gemm_wt_, CopyType::Vector, stream); + temps.push_back(gemm_wt_); - flip_spatial_dims_inplace(gemm_wt_); + // Calculate the total size of the spatial dimensions + int spatial_size = 1; + for (int d = 1; d < gemm_wt.ndim() - 1; ++d) { + spatial_size *= gemm_wt.shape(d); + } + encoder.set_output_array(gemm_wt_); + encoder.dispatch([gemm_wt_ptr = gemm_wt_.data(), + out_channels = gemm_wt.shape(0), + in_channels = gemm_wt.shape(-1), + spatial_size]() { + flip_spatial_dims_inplace( + gemm_wt_ptr, in_channels, out_channels, spatial_size); + }); gemm_wt = gemm_wt_; } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + temps.push_back(gemm_out); } - // Perform gemm - cblas_sgemm( - CblasRowMajor, - CblasNoTrans, // no trans A - CblasTrans, // transB - strided_reshape[0], // M - O, // N - strided_reshape[1], // K - 1.0f, // alpha - in_strided.data(), - strided_reshape[1], // lda - gemm_wt.data(), - strided_reshape[1], // ldb - 0.0f, // beta - gemm_out.data(), - O // ldc - ); + encoder.set_input_array(in_strided); + encoder.set_input_array(gemm_wt); + encoder.set_output_array(gemm_out); + + encoder.dispatch([in_strided_ptr = in_strided.data(), + gemm_wt_ptr = gemm_wt.data(), + gemm_out_ptr = gemm_out.data(), + strided_reshape = std::move(strided_reshape), + O]() { + // Perform gemm + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, // no trans A + CblasTrans, // transB + strided_reshape[0], // M + O, // N + strided_reshape[1], // K + 1.0f, // alpha + in_strided_ptr, + strided_reshape[1], // lda + gemm_wt_ptr, + strided_reshape[1], // ldb + 0.0f, // beta + gemm_out_ptr, + O // ldc + ); + }); // Copy results if needed if (out.dtype() != float32) { - copy(gemm_out, out, CopyType::Vector); + copy_inplace(gemm_out, out, CopyType::Vector, stream); } + encoder.add_temporaries(std::move(temps)); } /////////////////////////////////////////////////////////////////////////////// @@ -1070,19 +1265,20 @@ void conv_1D_cpu( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { + bool flip, + Stream stream) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( - in, wt, out, padding, wt_strides, wt_dilation); + in, wt, out, padding, wt_strides, wt_dilation, stream); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip); + in, wt, out, padding, wt_strides, wt_dilation, flip, stream); } return dispatch_slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); } 
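
Every CPU kernel touched by this diff follows the same recipe: allocate outputs up front, register inputs and outputs with the stream's command encoder, dispatch a lambda that captures raw pointers and sizes by value, and hand any scratch arrays (`temps`) to the encoder so their buffers outlive the enqueued task. A minimal sketch of that recipe, using only the APIs introduced in this diff (the `scale` helper itself is hypothetical, not part of the change):

    // Hypothetical helper illustrating the encoder pattern; not part of this diff.
    void scale(const array& in, array& out, float alpha, Stream stream) {
      // Allocate the output up front, on the scheduling thread.
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      auto& encoder = cpu::get_command_encoder(stream);
      encoder.set_input_array(in);
      encoder.set_output_array(out);

      // Capture raw pointers and sizes by value: the lambda runs later on the
      // stream's worker thread, possibly after the caller has returned.
      encoder.dispatch([in_ptr = in.data<float>(),
                        out_ptr = out.data<float>(),
                        size = in.size(),
                        alpha]() {
        for (size_t i = 0; i < size; ++i) {
          out_ptr[i] = alpha * in_ptr[i];
        }
      });
    }
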
void conv_2D_cpu( @@ -1093,16 +1289,17 @@ void conv_2D_cpu( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { + bool flip, + Stream stream) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip); + in, wt, out, padding, wt_strides, wt_dilation, flip, stream); } return dispatch_slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); } void conv_3D_cpu( @@ -1113,17 +1310,18 @@ void conv_3D_cpu( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, - bool flip) { + bool flip, + Stream stream) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip); + in, wt, out, padding, wt_strides, wt_dilation, flip, stream); } return dispatch_slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); } } // namespace @@ -1144,7 +1342,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { kernel_strides_, kernel_dilation_, input_dilation_, - flip_); + flip_, + stream()); } // 2D convolution else if (in.ndim() == (2 + 2)) { @@ -1156,7 +1355,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { kernel_strides_, kernel_dilation_, input_dilation_, - flip_); + flip_, + stream()); } // 1D convolution else if (in.ndim() == (1 + 2)) { @@ -1168,7 +1368,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { kernel_strides_, kernel_dilation_, input_dilation_, - flip_); + flip_, + stream()); } // Throw error else { diff --git a/mlx/backend/cpu/copy.cpp b/mlx/backend/cpu/copy.cpp index 6c14d0fa4..bb2b3a227 100644 --- a/mlx/backend/cpu/copy.cpp +++ b/mlx/backend/cpu/copy.cpp @@ -5,6 +5,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core { @@ -12,20 +13,29 @@ namespace mlx::core { namespace { template -void copy_single(const array& src, array& dst) { - auto val = static_cast(src.data()[0]); +void copy_single(const array& src, array& dst, Stream stream) { + auto src_ptr = src.data(); auto dst_ptr = dst.data(); - for (int i = 0; i < dst.size(); ++i) { - dst_ptr[i] = val; - } + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(src); + encoder.set_output_array(dst); + encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() { + auto val = static_cast(src_ptr[0]); + std::fill_n(dst_ptr, size, val); + }); } template -void copy_vector(const array& src, array& dst) { +void copy_vector(const array& src, array& dst, Stream stream) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); size_t size = src.data_size(); - std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(src); + encoder.set_output_array(dst); + encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() { + 
std::copy(src_ptr, src_ptr + size, dst_ptr); + }); } template @@ -56,151 +66,220 @@ template void copy_general_general( const array& src, array& dst, + Stream stream, const Shape& data_shape, const Strides& i_strides, const Strides& o_strides, int64_t i_offset, - int64_t o_offset) { - if (data_shape.empty()) { - auto val = static_cast(*(src.data() + i_offset)); - auto dst_ptr = dst.data() + o_offset; - *dst_ptr = val; - return; - } - auto [shape, strides] = - collapse_contiguous_dims(data_shape, {i_strides, o_strides}); + int64_t o_offset, + const std::optional& dynamic_i_offset, + const std::optional& dynamic_o_offset) { auto src_ptr = src.data() + i_offset; auto dst_ptr = dst.data() + o_offset; - int ndim = shape.size(); - if (ndim == 1) { - copy_dims( - src_ptr, dst_ptr, shape, strides[0], strides[1], 0); - return; - } else if (ndim == 2) { - copy_dims( - src_ptr, dst_ptr, shape, strides[0], strides[1], 0); - return; - } else if (ndim == 3) { - copy_dims( - src_ptr, dst_ptr, shape, strides[0], strides[1], 0); - return; - } - ContiguousIterator in(shape, strides[0], ndim - 3); - ContiguousIterator out(shape, strides[1], ndim - 3); - auto stride = std::accumulate( - shape.end() - 3, shape.end(), 1, std::multiplies()); - for (int64_t elem = 0; elem < src.size(); elem += stride) { - copy_dims( - src_ptr + in.loc, - dst_ptr + out.loc, - shape, - strides[0], - strides[1], - ndim - 3); - in.step(); - out.step(); - } + auto i_offset_ptr = + dynamic_i_offset ? dynamic_i_offset->data() : nullptr; + auto o_offset_ptr = + dynamic_o_offset ? dynamic_o_offset->data() : nullptr; + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(src); + encoder.set_output_array(dst); + encoder.dispatch([src_ptr, + dst_ptr, + size = src.size(), + data_shape = data_shape, + i_strides = i_strides, + o_strides = o_strides, + i_offset_ptr, + o_offset_ptr]() mutable { + if (data_shape.empty()) { + auto val = static_cast(*src_ptr); + *dst_ptr = val; + return; + } + auto [shape, strides] = + collapse_contiguous_dims(data_shape, {i_strides, o_strides}); + + int ndim = shape.size(); + if (ndim < 3) { + if (i_offset_ptr) { + src_ptr += i_offset_ptr[0]; + } + if (o_offset_ptr) { + dst_ptr += o_offset_ptr[0]; + } + + if (ndim == 1) { + copy_dims( + src_ptr, dst_ptr, shape, strides[0], strides[1], 0); + } else if (ndim == 2) { + copy_dims( + src_ptr, dst_ptr, shape, strides[0], strides[1], 0); + } else if (ndim == 3) { + copy_dims( + src_ptr, dst_ptr, shape, strides[0], strides[1], 0); + } + return; + } + if (i_offset_ptr) { + src_ptr += i_offset_ptr[0]; + } + if (o_offset_ptr) { + dst_ptr += o_offset_ptr[0]; + } + + ContiguousIterator in(shape, strides[0], ndim - 3); + ContiguousIterator out(shape, strides[1], ndim - 3); + auto stride = std::accumulate( + shape.end() - 3, shape.end(), 1, std::multiplies()); + for (int64_t elem = 0; elem < size; elem += stride) { + copy_dims( + src_ptr + in.loc, + dst_ptr + out.loc, + shape, + strides[0], + strides[1], + ndim - 3); + in.step(); + out.step(); + } + }); } template -inline void copy_general_general(const array& src, array& dst) { +inline void copy_general_general(const array& src, array& dst, Stream stream) { copy_general_general( - src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); + src, + dst, + stream, + src.shape(), + src.strides(), + dst.strides(), + 0, + 0, + std::nullopt, + std::nullopt); } template void copy_general( const array& src, array& dst, + Stream stream, const Shape& data_shape, const Strides& i_strides, const Strides&, int64_t 
i_offset, - int64_t o_offset) { + int64_t o_offset, + const std::optional& dynamic_i_offset, + const std::optional& dynamic_o_offset) { copy_general_general( src, dst, + stream, data_shape, i_strides, make_contiguous_strides(data_shape), i_offset, - o_offset); + o_offset, + dynamic_i_offset, + dynamic_o_offset); } template -inline void copy_general(const array& src, array& dst) { +inline void copy_general(const array& src, array& dst, Stream stream) { copy_general_general( src, dst, + stream, src.shape(), src.strides(), make_contiguous_strides(src.shape()), 0, - 0); + 0, + std::nullopt, + std::nullopt); } template -void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { +void copy( + const array& src, + array& dst, + CopyType ctype, + Stream stream, + Args&&... args) { switch (ctype) { case CopyType::Scalar: - copy_single(src, dst); + copy_single(src, dst, stream); return; case CopyType::Vector: - copy_vector(src, dst); + copy_vector(src, dst, stream); return; case CopyType::General: - copy_general(src, dst, std::forward(args)...); + copy_general(src, dst, stream, std::forward(args)...); return; case CopyType::GeneralGeneral: - copy_general_general(src, dst, std::forward(args)...); + copy_general_general( + src, dst, stream, std::forward(args)...); return; } } template -void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { +void copy( + const array& src, + array& dst, + CopyType ctype, + Stream stream, + Args&&... args) { switch (dst.dtype()) { case bool_: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case uint8: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case uint16: - copy(src, dst, ctype, std::forward(args)...); + copy( + src, dst, ctype, stream, std::forward(args)...); break; case uint32: - copy(src, dst, ctype, std::forward(args)...); + copy( + src, dst, ctype, stream, std::forward(args)...); break; case uint64: - copy(src, dst, ctype, std::forward(args)...); + copy( + src, dst, ctype, stream, std::forward(args)...); break; case int8: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case int16: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case int32: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case int64: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case float16: - copy(src, dst, ctype, std::forward(args)...); + copy( + src, dst, ctype, stream, std::forward(args)...); break; case float32: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case float64: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case bfloat16: - copy(src, dst, ctype, std::forward(args)...); + copy( + src, dst, ctype, stream, std::forward(args)...); break; case complex64: - copy(src, dst, ctype, std::forward(args)...); + copy( + src, dst, ctype, stream, std::forward(args)...); break; } } @@ -210,84 +289,71 @@ inline void copy_inplace_dispatch( const array& src, array& dst, CopyType ctype, + Stream stream, Args&&... 
args) { switch (src.dtype()) { case bool_: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case uint8: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case uint16: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case uint32: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case uint64: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case int8: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case int16: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case int32: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case int64: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case float16: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case float32: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case float64: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case bfloat16: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; case complex64: - copy(src, dst, ctype, std::forward(args)...); + copy(src, dst, ctype, stream, std::forward(args)...); break; } } } // namespace -void copy_inplace(const array& src, array& dst, CopyType ctype) { - copy_inplace_dispatch(src, dst, ctype); +void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) { + copy_inplace_dispatch(src, dst, ctype, stream); } -void copy(const array& src, array& dst, CopyType ctype) { - // Allocate the output - switch (ctype) { - case CopyType::Vector: - if (src.is_donatable() && src.itemsize() == dst.itemsize()) { - dst.copy_shared_buffer(src); - } else { - auto size = src.data_size(); - dst.set_data( - allocator::malloc_or_wait(size * dst.itemsize()), - size, - src.strides(), - src.flags()); - } - break; - case CopyType::Scalar: - case CopyType::General: - case CopyType::GeneralGeneral: - dst.set_data(allocator::malloc_or_wait(dst.nbytes())); - break; +void copy(const array& src, array& dst, CopyType ctype, Stream stream) { + bool donated = set_copy_output_data(src, dst, ctype); + if (donated && src.dtype() == dst.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. 
+    return;
   }
   if (ctype == CopyType::GeneralGeneral) {
     ctype = CopyType::General;
   }
-  copy_inplace(src, dst, ctype);
+  copy_inplace(src, dst, ctype, stream);
 }
 
 void copy_inplace(
@@ -298,7 +364,10 @@ void copy_inplace(
     const Strides& o_strides,
     int64_t i_offset,
     int64_t o_offset,
-    CopyType ctype) {
+    CopyType ctype,
+    Stream stream,
+    const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
+    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
   switch (ctype) {
     case CopyType::General:
     case CopyType::GeneralGeneral:
@@ -306,15 +375,18 @@ void copy_inplace(
           src,
           dst,
           ctype,
+          stream,
           data_shape,
           i_strides,
           o_strides,
           i_offset,
-          o_offset);
+          o_offset,
+          dynamic_i_offset,
+          dynamic_o_offset);
       break;
     case CopyType::Scalar:
     case CopyType::Vector:
-      copy_inplace_dispatch(src, dst, ctype);
+      copy_inplace_dispatch(src, dst, ctype, stream);
   }
 }
diff --git a/mlx/backend/cpu/copy.h b/mlx/backend/cpu/copy.h
index 1e8dc2530..e930da465 100644
--- a/mlx/backend/cpu/copy.h
+++ b/mlx/backend/cpu/copy.h
@@ -2,14 +2,16 @@
 
 #pragma once
 
+#include <optional>
+
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
 
-void copy(const array& src, array& dst, CopyType ctype);
-void copy_inplace(const array& src, array& dst, CopyType ctype);
+void copy(const array& src, array& dst, CopyType ctype, Stream stream);
+void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
 
 void copy_inplace(
     const array& src,
@@ -19,6 +21,9 @@ void copy_inplace(
     const Strides& o_strides,
     int64_t i_offset,
     int64_t o_offset,
-    CopyType ctype);
+    CopyType ctype,
+    Stream stream,
+    const std::optional<array>& dynamic_i_offset = std::nullopt,
+    const std::optional<array>& dynamic_o_offset = std::nullopt);
 
 } // namespace mlx::core
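
The new trailing parameters let the source and destination offsets be resolved when the task runs rather than when it is recorded: each optional holds a scalar array whose single integer is read inside the dispatched lambda (the `i_offset_ptr`/`o_offset_ptr` reads above). A sketch of a call site, assuming `offset` is a scalar int64 array produced by an earlier task on the same stream (the wrapper name is illustrative):

    // Sketch: shift the read position of `src` by an offset that is only
    // known once earlier tasks on `stream` have run.
    void copy_with_dynamic_offset(
        const array& src,
        array& dst,
        const array& offset, // scalar int64, written by an earlier task
        Stream stream) {
      copy_inplace(
          src,
          dst,
          src.shape(),
          src.strides(),
          dst.strides(),
          /* i_offset = */ 0,
          /* o_offset = */ 0,
          CopyType::GeneralGeneral,
          stream,
          /* dynamic_i_offset = */ offset,
          /* dynamic_o_offset = */ std::nullopt);
    }
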
diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp
new file mode 100644
index 000000000..1d6d9de06
--- /dev/null
+++ b/mlx/backend/cpu/distributed.cpp
@@ -0,0 +1,94 @@
+// Copyright © 2024 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/allocator.h"
+#include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
+#include "mlx/distributed/primitives.h"
+
+namespace mlx::core::distributed {
+
+std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
+  if (arr.flags().row_contiguous) {
+    return {arr, false};
+  } else {
+    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+    copy(arr, arr_copy, CopyType::General, stream);
+    return {arr_copy, true};
+  }
+};
+
+void AllReduce::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto donate_or_copy = [s = stream()](const array& in, array& out) {
+    if (in.flags().row_contiguous) {
+      if (in.is_donatable()) {
+        out.copy_shared_buffer(in);
+      } else {
+        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      }
+      return in;
+    } else {
+      array arr_copy(in.shape(), in.dtype(), nullptr, {});
+      copy(in, arr_copy, CopyType::General, s);
+      out.copy_shared_buffer(arr_copy);
+      return arr_copy;
+    }
+  };
+
+  auto in = donate_or_copy(inputs[0], outputs[0]);
+  switch (reduce_type_) {
+    case Sum:
+      distributed::detail::all_sum(group(), in, outputs[0], stream());
+      break;
+    default:
+      throw std::runtime_error("Only all reduce sum is supported for now");
+  }
+}
+
+void AllGather::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
+  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  distributed::detail::all_gather(group(), in, outputs[0], stream());
+  if (copied) {
+    auto& enc = cpu::get_command_encoder(stream());
+    enc.add_temporary(in);
+  }
+}
+
+void Send::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
+  distributed::detail::send(group(), in, dst_, stream());
+  outputs[0].copy_shared_buffer(inputs[0]);
+  if (copied) {
+    auto& enc = cpu::get_command_encoder(stream());
+    enc.add_temporary(in);
+  }
+}
+
+void Recv::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+
+  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  distributed::detail::recv(group(), outputs[0], src_, stream());
+}
+
+} // namespace mlx::core::distributed
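
Note that `ensure_row_contiguous` returns both the usable array and a flag saying whether a copy was made; a caller that receives a copy must park it on the encoder, since the collective runs asynchronously and the temporary would otherwise be freed when it goes out of scope. A sketch of that calling convention (the wrapper is hypothetical; `group` and `stream` play the same roles as `group()` and `stream()` in the ops above):

    // Hypothetical wrapper mirroring AllGather::eval_cpu above.
    void all_sum_contiguous(
        const array& in,
        array& out,
        distributed::Group group,
        Stream stream) {
      auto [in_arr, copied] = ensure_row_contiguous(in, stream);
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
      distributed::detail::all_sum(group, in_arr, out, stream);
      if (copied) {
        // The copy is consumed by an asynchronous task; keep it alive.
        cpu::get_command_encoder(stream).add_temporary(in_arr);
      }
    }
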
diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp
index c9ec2875f..348a27199 100644
--- a/mlx/backend/cpu/eigh.cpp
+++ b/mlx/backend/cpu/eigh.cpp
@@ -3,6 +3,7 @@
 #include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/linalg.h"
 #include "mlx/primitives.h"
@@ -16,59 +17,72 @@ void eigh_impl(
     array& vectors,
     array& values,
     const std::string& uplo,
-    bool compute_eigenvectors) {
+    bool compute_eigenvectors,
+    Stream stream) {
   auto vec_ptr = vectors.data<T>();
   auto eig_ptr = values.data<T>();
   char jobz = compute_eigenvectors ? 'V' : 'N';
-  auto N = vectors.shape(-1);
-  // Work query
-  int lwork = -1;
-  int liwork = -1;
-  int info;
-  {
-    T work;
-    int iwork;
-    syevd<T>(
-        &jobz,
-        uplo.c_str(),
-        &N,
-        nullptr,
-        &N,
-        nullptr,
-        &work,
-        &lwork,
-        &iwork,
-        &liwork,
-        &info);
-    lwork = static_cast<int>(work);
-    liwork = iwork;
-  }
-
-  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
-  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
-  for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
-    syevd<T>(
-        &jobz,
-        uplo.c_str(),
-        &N,
-        vec_ptr,
-        &N,
-        eig_ptr,
-        static_cast<T*>(work_buf.buffer.raw_ptr()),
-        &lwork,
-        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-        &liwork,
-        &info);
-    vec_ptr += N * N;
-    eig_ptr += N;
-    if (info != 0) {
-      std::stringstream msg;
-      msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-          << info;
-      throw std::runtime_error(msg.str());
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_output_array(vectors);
+  encoder.set_output_array(values);
+  encoder.dispatch([vec_ptr,
+                    eig_ptr,
+                    jobz,
+                    uplo = uplo[0],
+                    N = vectors.shape(-1),
+                    size = vectors.size()]() mutable {
+    // Work query
+    int lwork = -1;
+    int liwork = -1;
+    int info;
+    {
+      T work;
+      int iwork;
+      syevd<T>(
+          &jobz,
+          &uplo,
+          &N,
+          nullptr,
+          &N,
+          nullptr,
+          &work,
+          &lwork,
+          &iwork,
+          &liwork,
+          &info);
+      lwork = static_cast<int>(work);
+      liwork = iwork;
     }
+
+    auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+    auto iwork_buf =
+        array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+    for (size_t i = 0; i < size / (N * N); ++i) {
+      syevd<T>(
+          &jobz,
+          &uplo,
+          &N,
+          vec_ptr,
+          &N,
+          eig_ptr,
+          static_cast<T*>(work_buf.buffer.raw_ptr()),
+          &lwork,
+          static_cast<int*>(iwork_buf.buffer.raw_ptr()),
+          &liwork,
+          &info);
+      vec_ptr += N * N;
+      eig_ptr += N;
+      if (info != 0) {
+        std::stringstream msg;
+        msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
+            << info;
+        throw std::runtime_error(msg.str());
+      }
+    }
+  });
+  if (!compute_eigenvectors) {
+    encoder.add_temporary(vectors);
   }
 }
@@ -89,7 +103,8 @@ void Eigh::eval_cpu(
   copy(
       a,
       vectors,
-      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+      stream());
 
   if (compute_eigenvectors_) {
     // Set the strides and flags so the eigenvectors
@@ -107,14 +122,15 @@ void Eigh::eval_cpu(
         flags.col_contiguous = true;
       }
     }
-    vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
+    vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
   }
   switch (a.dtype()) {
     case float32:
-      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
+      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
       break;
     case float64:
-      eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
+      eigh_impl<double>(
+          vectors, values, uplo_, compute_eigenvectors_, stream());
       break;
     default:
       throw std::runtime_error(
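
The dispatched body keeps LAPACK's usual two-phase protocol: a first `syevd` call with `lwork == -1` only reports the optimal workspace sizes, and the real factorization loop then reuses one workspace across the whole batch. Standalone, the idiom looks roughly like this (a sketch for the float case, using the `syevd` wrapper from mlx/backend/cpu/lapack.h as above; error handling elided and the diagonal test matrix is just an example):

    char jobz = 'V'; // 'V': eigenvalues and eigenvectors, 'N': values only
    char uplo = 'L'; // which triangle of the symmetric matrix is stored
    int N = 4;       // matrix dimension (example value)
    std::vector<float> a(N * N, 0.0f);
    for (int i = 0; i < N; ++i) {
      a[i * N + i] = 1.0f + i; // simple diagonal test matrix
    }
    std::vector<float> w(N);   // eigenvalues are written here

    // Phase 1: lwork == -1 asks syevd only for the optimal workspace sizes.
    int lwork = -1, liwork = -1, info;
    float work_query;
    int iwork_query;
    syevd<float>(
        &jobz, &uplo, &N, nullptr, &N, nullptr,
        &work_query, &lwork, &iwork_query, &liwork, &info);
    lwork = static_cast<int>(work_query);
    liwork = iwork_query;

    // Phase 2: allocate once and factorize; a batched caller reuses these
    // buffers for every N x N matrix, as eigh_impl does above.
    std::vector<float> work(lwork);
    std::vector<int> iwork(liwork);
    syevd<float>(
        &jobz, &uplo, &N, a.data(), &N, w.data(),
        work.data(), &lwork, iwork.data(), &liwork, &info);
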
+ +#include "mlx/backend/cpu/encoder.h" + +namespace mlx::core::cpu { + +CommandEncoder& get_command_encoder(Stream stream) { + static std::unordered_map encoder_map; + auto it = encoder_map.find(stream.index); + if (it == encoder_map.end()) { + it = encoder_map.emplace(stream.index, stream).first; + } + return it->second; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/encoder.h b/mlx/backend/cpu/encoder.h new file mode 100644 index 000000000..aae64fb5e --- /dev/null +++ b/mlx/backend/cpu/encoder.h @@ -0,0 +1,53 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/scheduler.h" + +namespace mlx::core::cpu { + +struct CommandEncoder { + CommandEncoder(Stream stream) : stream_(stream) {} + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + CommandEncoder(CommandEncoder&&) = delete; + CommandEncoder& operator=(CommandEncoder&&) = delete; + + void set_input_array(const array& a) {} + void set_output_array(array& a) {} + + // Hold onto a temporary until any already scheduled tasks which use it as + // an input are complete. + void add_temporary(array arr) { + temporaries_.push_back(std::move(arr)); + } + + void add_temporaries(std::vector arrays) { + temporaries_.insert( + temporaries_.end(), + std::make_move_iterator(arrays.begin()), + std::make_move_iterator(arrays.end())); + } + + std::vector& temporaries() { + return temporaries_; + } + + template + void dispatch(F&& f, Args&&... args) { + auto task = std::bind(std::forward(f), std::forward(args)...); + scheduler::enqueue(stream_, std::move(task)); + } + + private: + Stream stream_; + std::vector temporaries_; +}; + +CommandEncoder& get_command_encoder(Stream stream); + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp new file mode 100644 index 000000000..04811e737 --- /dev/null +++ b/mlx/backend/cpu/eval.cpp @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. +#include "mlx/backend/cpu/eval.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" +#include "mlx/utils.h" + +namespace mlx::core::cpu { + +void eval(array& arr) { + auto s = arr.primitive().stream(); + + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_cpu(arr.inputs(), outputs); + } + + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + auto& encoder = cpu::get_command_encoder(s); + scheduler::notify_new_task(s); + encoder.dispatch([s, + buffers = std::move(buffers), + temps = std::move(encoder.temporaries())]() { + scheduler::notify_task_completion(s); + }); +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/eval.h b/mlx/backend/cpu/eval.h new file mode 100644 index 000000000..20156d61b --- /dev/null +++ b/mlx/backend/cpu/eval.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::cpu { + +void eval(array& arr); + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/fft.cpp b/mlx/backend/cpu/fft.cpp index 52bf80655..90227575b 100644 --- a/mlx/backend/cpu/fft.cpp +++ b/mlx/backend/cpu/fft.cpp @@ -4,6 +4,7 @@ #include "mlx/3rdparty/pocketfft.h" #include "mlx/allocator.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" namespace mlx::core { @@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector& inputs, array& out) { }); scale /= nelem; } + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_input_array(in); + encoder.set_output_array(out); + if (in.dtype() == complex64 && out.dtype() == complex64) { auto in_ptr = reinterpret_cast*>(in.data()); auto out_ptr = reinterpret_cast*>(out.data()); - pocketfft::c2c( - shape, - strides_in, - strides_out, - axes_, - !inverse_, - in_ptr, - out_ptr, - scale); + encoder.dispatch([shape = std::move(shape), + strides_in = std::move(strides_in), + strides_out = std::move(strides_out), + axes = axes_, + inverse = inverse_, + in_ptr, + out_ptr, + scale]() { + pocketfft::c2c( + shape, + strides_in, + strides_out, + axes, + !inverse, + in_ptr, + out_ptr, + scale); + }); } else if (in.dtype() == float32 && out.dtype() == complex64) { auto in_ptr = in.data(); auto out_ptr = reinterpret_cast*>(out.data()); - pocketfft::r2c( - shape, - strides_in, - strides_out, - axes_, - !inverse_, - in_ptr, - out_ptr, - scale); + encoder.dispatch([shape = std::move(shape), + strides_in = std::move(strides_in), + strides_out = std::move(strides_out), + axes = axes_, + inverse = inverse_, + in_ptr, + out_ptr, + scale]() { + pocketfft::r2c( + shape, + strides_in, + strides_out, + axes, + !inverse, + in_ptr, + out_ptr, + scale); + }); } else if (in.dtype() == complex64 && out.dtype() == float32) { auto in_ptr = reinterpret_cast*>(in.data()); auto out_ptr = out.data(); - pocketfft::c2r( - shape, - strides_in, - strides_out, - axes_, - !inverse_, - in_ptr, - out_ptr, - scale); + encoder.dispatch([shape = std::move(shape), + strides_in = std::move(strides_in), + strides_out = std::move(strides_out), + axes = axes_, + inverse = inverse_, + in_ptr, + out_ptr, + scale]() { + pocketfft::c2r( + shape, + strides_in, + strides_out, + axes, + !inverse, + in_ptr, + out_ptr, + scale); + }); } else { throw std::runtime_error( "[FFT] Received unexpected input and output type combination."); diff --git a/mlx/backend/cpu/gemm.h b/mlx/backend/cpu/gemm.h index 008c29157..d665cb91f 100644 --- a/mlx/backend/cpu/gemm.h +++ b/mlx/backend/cpu/gemm.h @@ -7,14 +7,20 @@ namespace mlx::core { template void matmul( - const array& a, - const array& b, - array& out, + const T* a, + const T* b, + T* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, + size_t ldc, float alpha, - float beta); + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides); } // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/bnns.cpp b/mlx/backend/cpu/gemms/bnns.cpp index a9d09380c..0ba7a9b8c 100644 --- a/mlx/backend/cpu/gemms/bnns.cpp +++ b/mlx/backend/cpu/gemms/bnns.cpp @@ -9,39 +9,46 @@ namespace mlx::core { -BNNSDataType to_bnns_dtype(Dtype mlx_dtype) { - uint32_t size_bits = size_of(mlx_dtype) * 8; - switch (kindof(mlx_dtype)) { - case Dtype::Kind::b: - return BNNSDataTypeBoolean; - case Dtype::Kind::u: - return BNNSDataType(BNNSDataTypeUIntBit | size_bits); - case 
Dtype::Kind::i: - return BNNSDataType(BNNSDataTypeIntBit | size_bits); - case Dtype::Kind::f: - return BNNSDataType(BNNSDataTypeFloatBit | size_bits); - case Dtype::Kind::V: - return BNNSDataTypeBFloat16; - case Dtype::Kind::c: - throw std::invalid_argument("BNNS does not support complex types"); - } +template +constexpr BNNSDataType to_bnns_dtype(); + +template <> +constexpr BNNSDataType to_bnns_dtype() { + return BNNSDataType(BNNSDataTypeFloatBit | 32); +} +template <> +constexpr BNNSDataType to_bnns_dtype() { + return BNNSDataType(BNNSDataTypeFloatBit | 16); } +template <> +constexpr BNNSDataType to_bnns_dtype() { + return BNNSDataTypeBFloat16; +} + +template void matmul_bnns( - const array& a, - const array& b, - array& out, + const T* a, + const T* b, + T* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, + size_t ldc, float alpha, - float beta) { - size_t M = a.shape(-2); - size_t N = b.shape(-1); - size_t K = a.shape(-1); + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides) { + auto ndim = a_shape.size(); + size_t M = a_shape[ndim - 2]; + size_t N = b_shape[ndim - 1]; + size_t K = a_shape[ndim - 1]; - BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype()); + BNNSDataType bnns_dtype = to_bnns_dtype(); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdeprecated-declarations" @@ -115,14 +122,14 @@ void matmul_bnns( auto bnns_filter = BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr); - for (int i = 0; i < (a.size() / (M * K)); ++i) { + for (int i = 0; i < batch_size; ++i) { BNNSFilterApplyTwoInput( bnns_filter, - a.data() + - elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(), - b.data() + - elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(), - out.data() + M * N * i * out.itemsize()); + reinterpret_cast( + a + elem_to_loc(M * K * i, a_shape, a_strides)), + reinterpret_cast( + b + elem_to_loc(K * N * i, b_shape, b_strides)), + reinterpret_cast(out + M * N * i)); } BNNSFilterDestroy(bnns_filter); @@ -131,30 +138,72 @@ void matmul_bnns( template <> void matmul( - const array& a, - const array& b, - array& out, + const float16_t* a, + const float16_t* b, + float16_t* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, + size_t ldc, float alpha, - float beta) { - matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides) { + matmul_bnns( + a, + b, + out, + a_transposed, + b_transposed, + lda, + ldb, + ldc, + alpha, + beta, + batch_size, + a_shape, + a_strides, + b_shape, + b_strides); } template <> void matmul( - const array& a, - const array& b, - array& out, + const bfloat16_t* a, + const bfloat16_t* b, + bfloat16_t* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, + size_t ldc, float alpha, - float beta) { - matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides) { + matmul_bnns( + a, + b, + out, + a_transposed, + b_transposed, + lda, + ldb, + ldc, + alpha, + beta, + batch_size, + a_shape, + a_strides, + b_shape, + b_strides); } } // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/cblas.cpp b/mlx/backend/cpu/gemms/cblas.cpp index 912098a1a..13eaeca91 100644 --- a/mlx/backend/cpu/gemms/cblas.cpp 
+++ b/mlx/backend/cpu/gemms/cblas.cpp @@ -8,20 +8,27 @@ namespace mlx::core { template <> void matmul( - const array& a, - const array& b, - array& out, + const float* a, + const float* b, + float* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, + size_t ldc, float alpha, - float beta) { - size_t M = a.shape(-2); - size_t N = b.shape(-1); - size_t K = a.shape(-1); + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides) { + auto ndim = a_shape.size(); + size_t M = a_shape[ndim - 2]; + size_t N = b_shape[ndim - 1]; + size_t K = a_shape[ndim - 1]; - for (int i = 0; i < (a.size() / (M * K)); ++i) { + for (int i = 0; i < batch_size; ++i) { cblas_sgemm( CblasRowMajor, a_transposed ? CblasTrans : CblasNoTrans, // transA @@ -29,34 +36,40 @@ void matmul( M, N, K, - alpha, // alpha - a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), + alpha, + a + elem_to_loc(M * K * i, a_shape, a_strides), lda, - b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), + b + elem_to_loc(K * N * i, b_shape, b_strides), ldb, - beta, // beta - out.data() + M * N * i, - out.shape(-1) // ldc - ); + beta, + out + M * N * i, + ldc); } } template <> void matmul( - const array& a, - const array& b, - array& out, + const double* a, + const double* b, + double* out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, + size_t ldc, float alpha, - float beta) { - size_t M = a.shape(-2); - size_t N = b.shape(-1); - size_t K = a.shape(-1); + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides) { + auto ndim = a_shape.size(); + size_t M = a_shape[ndim - 2]; + size_t N = b_shape[ndim - 1]; + size_t K = a_shape[ndim - 1]; - for (int i = 0; i < (a.size() / (M * K)); ++i) { + for (int i = 0; i < batch_size; ++i) { cblas_dgemm( CblasRowMajor, a_transposed ? 
CblasTrans : CblasNoTrans, // transA @@ -64,15 +77,14 @@ void matmul( M, N, K, - alpha, // alpha - a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), + alpha, + a + elem_to_loc(M * K * i, a_shape, a_strides), lda, - b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), + b + elem_to_loc(K * N * i, b_shape, b_strides), ldb, - beta, // beta - out.data() + M * N * i, - out.shape(-1) // ldc - ); + beta, + out + M * N * i, + ldc); } } diff --git a/mlx/backend/cpu/gemms/no_bf16.cpp b/mlx/backend/cpu/gemms/no_bf16.cpp index bf470779d..157c07f46 100644 --- a/mlx/backend/cpu/gemms/no_bf16.cpp +++ b/mlx/backend/cpu/gemms/no_bf16.cpp @@ -6,15 +6,21 @@ namespace mlx::core { template <> void matmul( - const array&, - const array&, - array&, + const bfloat16_t*, + const bfloat16_t*, + bfloat16_t*, bool, bool, size_t, size_t, + size_t, float, - float) { + float, + size_t, + const Shape&, + const Strides&, + const Shape&, + const Strides&) { throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported."); } diff --git a/mlx/backend/cpu/gemms/no_fp16.cpp b/mlx/backend/cpu/gemms/no_fp16.cpp index 7b39d4b30..3f3f41cc5 100644 --- a/mlx/backend/cpu/gemms/no_fp16.cpp +++ b/mlx/backend/cpu/gemms/no_fp16.cpp @@ -6,15 +6,21 @@ namespace mlx::core { template <> void matmul( - const array&, - const array&, - array&, + const float16_t*, + const float16_t*, + float16_t*, bool, bool, size_t, size_t, + size_t, float, - float) { + float, + size_t, + const Shape&, + const Strides&, + const Shape&, + const Strides&) { throw std::runtime_error("[Matmul::eval_cpu] float16 not supported."); } diff --git a/mlx/backend/cpu/hadamard.cpp b/mlx/backend/cpu/hadamard.cpp index eaeac83db..b0734ad79 100644 --- a/mlx/backend/cpu/hadamard.cpp +++ b/mlx/backend/cpu/hadamard.cpp @@ -4,16 +4,17 @@ #include "mlx/backend/common/hadamard.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" namespace mlx::core { // n = 2^k component template -void hadamard_n(array& out, int n, int m, float scale) { - for (int b = 0; b < out.size() / n; b++) { +void hadamard_n(T* out, int n, int m, float scale, size_t size) { + for (int b = 0; b < size / n; b++) { size_t loc = b * n; - T* data_ptr = out.data() + loc; + T* data_ptr = out + loc; int h = 1; int n_over_2 = n / 2; while (h < n) { @@ -36,7 +37,7 @@ void hadamard_n(array& out, int n, int m, float scale) { // m component template -void hadamard_m(array& out, int n, int m, float scale) { +void hadamard_m(T* out, int n, int m, float scale, size_t size) { auto h_matrices = hadamard_matrices(); auto& matrix = h_matrices[m]; auto start = 1; @@ -51,9 +52,9 @@ void hadamard_m(array& out, int n, int m, float scale) { end = matrix.find('\n', start); } - for (int b = 0; b < out.size() / m / n; b++) { + for (int b = 0; b < size / m / n; b++) { size_t loc = b * n * m; - T* data_ptr = out.data() + loc; + T* data_ptr = out + loc; for (int i = 0; i < n; i++) { std::vector out(m); for (int j = 0; j < m; j++) { @@ -74,12 +75,17 @@ void hadamard_m(array& out, int n, int m, float scale) { } template -void hadamard(array& out, int n, int m, float scale) { - float n_scale = m > 1 ? 1.0 : scale; - hadamard_n(out, n, m, n_scale); - if (m > 1) { - hadamard_m(out, n, m, scale); - } +void hadamard(array& out, int n, int m, float scale, Stream stream) { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(out); + auto out_ptr = out.data(); + encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() { + float n_scale = m > 1 ? 
1.0 : scale; + hadamard_n(out_ptr, n, m, n_scale, size); + if (m > 1) { + hadamard_m(out_ptr, n, m, scale, size); + } + }); } void Hadamard::eval_cpu(const std::vector& inputs, array& out) { @@ -87,18 +93,26 @@ void Hadamard::eval_cpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; // Copy input to output - copy(in, out, CopyType::General); + if (in.flags().row_contiguous && in.is_donatable()) { + out.copy_shared_buffer(in); + } else { + copy( + in, + out, + in.flags().row_contiguous ? CopyType::Vector : CopyType::General, + stream()); + } int axis = out.ndim() - 1; auto [n, m] = decompose_hadamard(out.shape(axis)); switch (in.dtype()) { case float32: - return hadamard(out, n, m, scale_); + return hadamard(out, n, m, scale_, stream()); case float16: - return hadamard(out, n, m, scale_); + return hadamard(out, n, m, scale_, stream()); case bfloat16: - return hadamard(out, n, m, scale_); + return hadamard(out, n, m, scale_, stream()); default: throw std::invalid_argument("[hadamard] Unsupported type."); } diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 3a0b63f34..0f4c082cc 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -8,6 +8,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" namespace mlx::core { @@ -27,7 +28,8 @@ void gather( const std::vector& inds, array& out, const std::vector& axes, - const Shape& slice_sizes) { + const Shape& slice_sizes, + Stream stream) { // If the array is row contiguous then we can do a contiguous copy given // two conditions on the slice size: // - Any number of leading ones in the slice sizes are allowed @@ -73,38 +75,60 @@ void gather( size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size; const T* src_ptr = src.data(); T* dst_ptr = out.data(); - size_t out_idx = 0; std::vector its(inds.begin(), inds.end()); ContiguousIterator src_it; if (!can_copy && src.ndim() > 0) { src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim()); } - for (int idx = 0; idx < ind_size; idx++) { - size_t src_idx = 0; - for (int ii = 0; ii < inds.size(); ++ii) { - auto ax = axes[ii]; - auto idx_loc = its[ii].loc; - its[ii].step(); - auto idx_val = - offset_neg_idx(inds[ii].data()[idx_loc], src.shape(ax)); - src_idx += (idx_val * src.strides()[ax]); - } - if (slice_size == 1) { - dst_ptr[out_idx++] = src_ptr[src_idx]; - } else if (can_copy) { - std::copy( - src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx); - out_idx += slice_size; - } else { - for (int jj = 0; jj < slice_size; jj++) { - dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc]; - src_it.step(); - } - src_it.reset(); - } + std::vector ind_ptrs; + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(src); + for (auto& idx : inds) { + ind_ptrs.push_back(idx.data()); + encoder.set_input_array(idx); } + encoder.set_output_array(out); + encoder.dispatch([src_ptr, + dst_ptr, + ind_ptrs = std::move(ind_ptrs), + axes, + ind_size, + slice_size, + src_shape = src.shape(), + src_strides = src.strides(), + src_it = std::move(src_it), + its = std::move(its), + can_copy]() mutable { + size_t out_idx = 0; + for (int idx = 0; idx < ind_size; idx++) { + size_t src_idx = 0; + for (int ii = 0; ii < ind_ptrs.size(); ++ii) { + auto ax = axes[ii]; + auto idx_loc = its[ii].loc; + its[ii].step(); + auto idx_val = offset_neg_idx(ind_ptrs[ii][idx_loc], src_shape[ax]); + src_idx += (idx_val * src_strides[ax]); + } + + if (slice_size == 1) { + 
dst_ptr[out_idx++] = src_ptr[src_idx]; + } else if (can_copy) { + std::copy( + src_ptr + src_idx, + src_ptr + src_idx + slice_size, + dst_ptr + out_idx); + out_idx += slice_size; + } else { + for (int jj = 0; jj < slice_size; jj++) { + dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc]; + src_it.step(); + } + src_it.reset(); + } + } + }); } template @@ -113,49 +137,50 @@ void dispatch_gather( const std::vector& inds, array& out, const std::vector& axes, - const Shape& size) { + const Shape& size, + Stream stream) { switch (out.dtype()) { case bool_: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case uint8: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case uint16: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case uint32: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case uint64: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case int8: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case int16: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case int32: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case int64: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case float16: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case float32: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case float64: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case bfloat16: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; case complex64: - gather(src, inds, out, axes, size); + gather(src, inds, out, axes, size, stream); break; } } @@ -167,34 +192,34 @@ void Gather::eval_cpu(const std::vector& inputs, array& out) { std::vector inds(inputs.begin() + 1, inputs.end()); if (inds.empty()) { - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); return; } switch (inds[0].dtype()) { case uint8: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; case uint16: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; case uint32: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; case uint64: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; case int8: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; case int16: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; case int32: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; case int64: - dispatch_gather(src, inds, out, axes_, slice_sizes_); + dispatch_gather(src, inds, out, axes_, slice_sizes_, stream()); break; default: throw std::runtime_error( @@ -207,7 +232,8 @@ void gather_axis( const array& src, const array& 
ind, array& out, - const int axis) { + const int axis, + Stream stream) { auto strides = ind.strides(); strides.erase(strides.begin() + axis); auto shape = ind.shape(); @@ -235,20 +261,39 @@ void gather_axis( for (int i = axis + 1; i < ind.ndim(); ++i) { size_post *= ind.shape(i); } - size_t stride_pre = size_post * ind_ax_size; - for (size_t i = 0; i < size_pre; i++) { - for (size_t k = 0; k < size_post; k++) { - for (int j = 0; j < ind_ax_size; ++j) { - auto ind_val = offset_neg_idx( - ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size); - dst_ptr[k + j * dst_ax_stride] = - src_ptr[src_it.loc + ind_val * src_ax_stride]; + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(src); + encoder.set_input_array(ind); + encoder.set_output_array(out); + + encoder.dispatch([ind_ptr, + src_ptr, + dst_ptr, + size_pre, + size_post, + ind_ax_size, + src_ax_size, + ind_ax_stride, + src_ax_stride, + dst_ax_stride, + ind_it = std::move(ind_it), + src_it = std::move(src_it)]() mutable { + size_t stride_pre = size_post * ind_ax_size; + for (size_t i = 0; i < size_pre; i++) { + for (size_t k = 0; k < size_post; k++) { + for (int j = 0; j < ind_ax_size; ++j) { + auto ind_val = offset_neg_idx( + ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size); + dst_ptr[k + j * dst_ax_stride] = + src_ptr[src_it.loc + ind_val * src_ax_stride]; + } + ind_it.step(); + src_it.step(); } - ind_it.step(); - src_it.step(); + dst_ptr += stride_pre; } - dst_ptr += stride_pre; - } + }); } template @@ -256,49 +301,50 @@ void dispatch_gather_axis( const array& src, const array& inds, array& out, - const int axis) { + const int axis, + Stream stream) { switch (out.dtype()) { case bool_: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case uint8: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case uint16: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case uint32: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case uint64: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case int8: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case int16: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case int32: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case int64: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case float16: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case float32: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case float64: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case bfloat16: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; case complex64: - gather_axis(src, inds, out, axis); + gather_axis(src, inds, out, axis, stream); break; } } @@ -309,28 +355,28 @@ void GatherAxis::eval_cpu(const std::vector& inputs, array& out) { auto& inds = inputs[1]; switch (inds.dtype()) { case uint8: - dispatch_gather_axis(src, inds, out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; case uint16: - dispatch_gather_axis(src, inds, out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; case uint32: - dispatch_gather_axis(src, inds, 
out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; case uint64: - dispatch_gather_axis(src, inds, out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; case int8: - dispatch_gather_axis(src, inds, out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; case int16: - dispatch_gather_axis(src, inds, out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; case int32: - dispatch_gather_axis(src, inds, out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; case int64: - dispatch_gather_axis(src, inds, out, axis_); + dispatch_gather_axis(src, inds, out, axis_, stream()); break; default: throw std::runtime_error( @@ -345,7 +391,8 @@ void scatter( array& out, const std::vector& inds, const std::vector& axes, - const OpT& op) { + const OpT& op, + Stream stream) { int nind = inds.size(); auto inds_ndim = updates.ndim() - out.ndim(); size_t n_updates = nind ? inds[0].size() : 1; @@ -361,26 +408,45 @@ void scatter( ContiguousIterator update_it(updates); ContiguousIterator out_it(update_shape, out.strides(), out.ndim()); - for (int i = 0; i < n_updates; ++i) { - size_t out_offset = 0; - for (int j = 0; j < nind; ++j) { - auto ax = axes[j]; - auto idx_loc = its[j].loc; - its[j].step(); - auto idx_val = - offset_neg_idx(inds[j].data()[idx_loc], out.shape(ax)); - out_offset += (idx_val * out.strides()[ax]); - } - update_it.seek(i * update_size); - for (int j = 0; j < update_size; ++j) { - op(updates.data()[update_it.loc], - out.data() + out_offset + out_it.loc); - update_it.step(); - out_it.step(); - } - out_it.reset(); - update_it.reset(); + std::vector ind_ptrs; + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(updates); + for (auto& idx : inds) { + ind_ptrs.push_back(idx.data()); + encoder.set_input_array(idx); } + encoder.set_output_array(out); + encoder.dispatch([out_ptr = out.data(), + upd_ptr = updates.data(), + ind_ptrs = std::move(ind_ptrs), + axes, + n_updates, + update_size, + op = std::move(op), + out_shape = out.shape(), + out_strides = out.strides(), + out_it = std::move(out_it), + update_it = std::move(update_it), + its = std::move(its)]() mutable { + for (int i = 0; i < n_updates; ++i) { + size_t out_offset = 0; + for (int j = 0; j < ind_ptrs.size(); ++j) { + auto ax = axes[j]; + auto idx_loc = its[j].loc; + its[j].step(); + auto idx_val = offset_neg_idx(ind_ptrs[j][idx_loc], out_shape[ax]); + out_offset += (idx_val * out_strides[ax]); + } + update_it.seek(i * update_size); + for (int j = 0; j < update_size; ++j) { + op(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc); + update_it.step(); + out_it.step(); + } + out_it.reset(); + update_it.reset(); + } + }); } template @@ -389,29 +455,53 @@ void dispatch_scatter_inds( const std::vector& indices, const array& updates, const std::vector& axes, - Scatter::ReduceType rtype) { + Scatter::ReduceType rtype, + Stream stream) { switch (rtype) { case Scatter::None: scatter( - updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; }); + updates, + out, + indices, + axes, + [](auto x, auto* y) { (*y) = x; }, + stream); break; case Scatter::Sum: scatter( - updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; }); + updates, + out, + indices, + axes, + [](auto x, auto* y) { (*y) += x; }, + stream); break; case Scatter::Prod: scatter( - updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; }); + updates, + out, + indices, + axes, + [](auto x, auto* y) { (*y) *= x; }, + stream); break; 
case Scatter::Max: - scatter(updates, out, indices, axes, [](auto x, auto* y) { - (*y) = (*y > x) ? *y : x; - }); + scatter( + updates, + out, + indices, + axes, + [](auto x, auto* y) { (*y) = (*y > x) ? *y : x; }, + stream); break; case Scatter::Min: - scatter(updates, out, indices, axes, [](auto x, auto* y) { - (*y) = (*y < x) ? *y : x; - }); + scatter( + updates, + out, + indices, + axes, + [](auto x, auto* y) { (*y) = (*y < x) ? *y : x; }, + stream); break; } } @@ -422,36 +512,46 @@ void dispatch_scatter( const std::vector& inds, const array& updates, const std::vector& axes, - Scatter::ReduceType rtype) { + Scatter::ReduceType rtype, + Stream stream) { if (inds.empty()) { - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); return; } switch (inds[0].dtype()) { case uint8: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; case uint16: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; case uint32: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; case uint64: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; case int8: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; case int16: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; case int32: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; case int64: - dispatch_scatter_inds(out, inds, updates, axes, rtype); + dispatch_scatter_inds( + out, inds, updates, axes, rtype, stream); break; default: throw std::runtime_error( @@ -469,50 +569,63 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { // Copy src into out (copy allocates memory for out) auto ctype = src.flags().row_contiguous ? 
CopyType::Vector : CopyType::General; - copy(src, out, ctype); + copy(src, out, ctype, stream()); switch (src.dtype()) { case bool_: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter(out, inds, updates, axes_, reduce_type_, stream()); break; case uint8: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case uint16: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case uint32: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case uint64: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case int8: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case int16: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case int32: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case int64: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case float16: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case float32: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case float64: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case bfloat16: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; case complex64: - dispatch_scatter(out, inds, updates, axes_, reduce_type_); + dispatch_scatter( + out, inds, updates, axes_, reduce_type_, stream()); break; } } @@ -523,7 +636,8 @@ void scatter_axis( const array idx, const array& upd, int axis, - const OpT& op) { + const OpT& op, + Stream stream) { auto strides = idx.strides(); strides.erase(strides.begin() + axis); auto shape = idx.shape(); @@ -543,6 +657,11 @@ void scatter_axis( auto idx_ax_size = idx.shape(axis); auto dst_ax_size = out.shape(axis); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(idx); + encoder.set_input_array(upd); + encoder.set_output_array(out); + size_t size_pre = 1; size_t size_post = 1; for (int i = 0; i < axis; ++i) { @@ -551,20 +670,34 @@ void scatter_axis( for (int i = axis + 1; i < idx.ndim(); ++i) { size_post *= idx.shape(i); } - size_t stride_pre = size_post * dst_ax_size; - for (size_t i = 0; i < size_pre; i++) { - for (size_t k = 0; k < size_post; k++) { - for (int j = 0; j < idx_ax_size; ++j) { - auto ind_val = offset_neg_idx( - idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); - op(upd_ptr[upd_it.loc + j * upd_ax_stride], - dst_ptr + k + ind_val * dst_ax_stride); + encoder.dispatch([idx_ptr, + upd_ptr, + dst_ptr, + size_pre, + size_post, + idx_ax_size, + dst_ax_size, + idx_ax_stride, + upd_ax_stride, + dst_ax_stride, + idx_it = std::move(idx_it), + upd_it = std::move(upd_it), + 
op = std::move(op)]() mutable { + size_t stride_pre = size_post * dst_ax_size; + for (size_t i = 0; i < size_pre; i++) { + for (size_t k = 0; k < size_post; k++) { + for (int j = 0; j < idx_ax_size; ++j) { + auto ind_val = offset_neg_idx( + idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); + op(upd_ptr[upd_it.loc + j * upd_ax_stride], + dst_ptr + k + ind_val * dst_ax_stride); + } + idx_it.step(); + upd_it.step(); } - idx_it.step(); - upd_it.step(); + dst_ptr += stride_pre; } - dst_ptr += stride_pre; - } + }); } template @@ -573,15 +706,16 @@ void dispatch_scatter_axis_op( const array& idx, const array& updates, int axis, - ScatterAxis::ReduceType rtype) { + ScatterAxis::ReduceType rtype, + Stream stream) { switch (rtype) { case ScatterAxis::None: scatter_axis( - out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }); + out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }, stream); break; case ScatterAxis::Sum: scatter_axis( - out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }); + out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }, stream); break; } } @@ -592,31 +726,40 @@ void dispatch_scatter_axis( const array& idx, const array& updates, int axis, - ScatterAxis::ReduceType rtype) { + ScatterAxis::ReduceType rtype, + Stream stream) { switch (idx.dtype()) { case uint8: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; case uint16: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; case uint32: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; case uint64: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; case int8: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; case int16: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; case int32: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; case int64: - dispatch_scatter_axis_op(out, idx, updates, axis, rtype); + dispatch_scatter_axis_op( + out, idx, updates, axis, rtype, stream); break; default: throw std::runtime_error( @@ -634,51 +777,64 @@ void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { // Copy src into out (copy allocates memory for out) auto ctype = src.flags().row_contiguous ? 
CopyType::Vector : CopyType::General; - copy(src, out, ctype); + copy(src, out, ctype, stream()); switch (src.dtype()) { case bool_: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case uint8: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case uint16: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case uint32: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case uint64: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case int8: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case int16: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case int32: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case int64: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case float16: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case float32: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case float64: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case bfloat16: - dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_, stream()); break; case complex64: dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_); + out, idx, updates, axis_, reduce_type_, stream()); break; } } diff --git a/mlx/backend/cpu/inverse.cpp b/mlx/backend/cpu/inverse.cpp index aba038218..9d9b71497 100644 --- a/mlx/backend/cpu/inverse.cpp +++ b/mlx/backend/cpu/inverse.cpp @@ -2,20 +2,21 @@ #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { template -void general_inv(array& inv, int N, int i) { +void general_inv(T* inv, int N) { int info; auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; // Compute LU factorization. getrf( /* m = */ &N, /* n = */ &N, - /* a = */ inv.data() + N * N * i, + /* a = */ inv, /* lda = */ &N, /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), /* info = */ &info); @@ -53,7 +54,7 @@ void general_inv(array& inv, int N, int i) { // Compute inverse. 
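// Illustrative sketch (not from the patch): the two-step inversion that
// general_inv performs, written against raw LAPACK symbols (sgetrf_/sgetri_)
// rather than MLX's getrf/getri wrappers. Because LAPACK is column-major and
// MLX buffers are row-major, factorizing the buffer as-is operates on Aᵀ;
// since (Aᵀ)⁻¹ = (A⁻¹)ᵀ, the in-place result is exactly the row-major inverse.
#include <stdexcept>
#include <vector>

extern "C" {
void sgetrf_(const int* m, const int* n, float* a, const int* lda,
             int* ipiv, int* info);
void sgetri_(const int* n, float* a, const int* lda, const int* ipiv,
             float* work, const int* lwork, int* info);
}

void invert_inplace(float* a, int n) {
  std::vector<int> ipiv(n);
  int info = 0;
  sgetrf_(&n, &n, a, &n, ipiv.data(), &info); // LU factorization
  if (info != 0) {
    throw std::runtime_error("getrf failed");
  }
  float wkopt = 0.0f;
  int lwork = -1;
  sgetri_(&n, a, &n, ipiv.data(), &wkopt, &lwork, &info); // workspace query
  lwork = static_cast<int>(wkopt);
  std::vector<float> work(lwork);
  sgetri_(&n, a, &n, ipiv.data(), work.data(), &lwork, &info); // invert
  if (info != 0) {
    throw std::runtime_error("getri failed");
  }
}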
getri( /* m = */ &N, - /* a = */ inv.data() + N * N * i, + /* a = */ inv, /* lda = */ &N, /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), /* work = */ static_cast(scratch.buffer.raw_ptr()), @@ -68,29 +69,28 @@ void general_inv(array& inv, int N, int i) { } template -void tri_inv(array& inv, int N, int i, bool upper) { +void tri_inv(T* inv, int N, bool upper) { const char uplo = upper ? 'L' : 'U'; const char diag = 'N'; - T* data = inv.data() + N * N * i; int info; trtri( /* uplo = */ &uplo, /* diag = */ &diag, /* N = */ &N, - /* a = */ data, + /* a = */ inv, /* lda = */ &N, /* info = */ &info); // zero out the other triangle if (upper) { for (int i = 0; i < N; i++) { - std::fill(data, data + i, 0.0f); - data += N; + std::fill(inv, inv + i, 0.0f); + inv += N; } } else { for (int i = 0; i < N; i++) { - std::fill(data + i + 1, data + N, 0.0f); - data += N; + std::fill(inv + i + 1, inv + N, 0.0f); + inv += N; } } @@ -103,34 +103,53 @@ void tri_inv(array& inv, int N, int i, bool upper) { } template -void inverse_impl(const array& a, array& inv, bool tri, bool upper) { +void inverse_impl( + const array& a, + array& inv, + bool tri, + bool upper, + Stream stream) { // Lapack uses the column-major convention. We take advantage of the following // identity to avoid transposing (see // https://math.stackexchange.com/a/340234): // (A⁻¹)ᵀ = (Aᵀ)⁻¹ // The inverse is computed in place, so just copy the input to the output. - copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General); + copy( + a, + inv, + a.flags().row_contiguous ? CopyType::Vector : CopyType::General, + stream); const int N = a.shape(-1); const size_t num_matrices = a.size() / (N * N); - for (int i = 0; i < num_matrices; i++) { - if (tri) { - tri_inv(inv, N, i, upper); - } else { - general_inv(inv, N, i); - } + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(inv); + + auto inv_ptr = inv.data(); + if (tri) { + encoder.dispatch([inv_ptr, N, num_matrices, upper]() { + for (int i = 0; i < num_matrices; i++) { + tri_inv(inv_ptr + N * N * i, N, upper); + } + }); + } else { + encoder.dispatch([inv_ptr, N, num_matrices]() { + for (int i = 0; i < num_matrices; i++) { + general_inv(inv_ptr + N * N * i, N); + } + }); } } void Inverse::eval_cpu(const std::vector& inputs, array& output) { switch (inputs[0].dtype()) { case float32: - inverse_impl(inputs[0], output, tri_, upper_); + inverse_impl(inputs[0], output, tri_, upper_, stream()); break; case float64: - inverse_impl(inputs[0], output, tri_, upper_); + inverse_impl(inputs[0], output, tri_, upper_, stream()); break; default: throw std::runtime_error( diff --git a/mlx/backend/cpu/luf.cpp b/mlx/backend/cpu/luf.cpp index 87de97c72..d85de146c 100644 --- a/mlx/backend/cpu/luf.cpp +++ b/mlx/backend/cpu/luf.cpp @@ -4,15 +4,22 @@ #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { template -void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) { +void luf_impl( + const array& a, + array& lu, + array& pivots, + array& row_indices, + Stream stream) { int M = a.shape(-2); int N = a.shape(-1); + int K = std::min(M, N); // Copy a into lu and make it col contiguous auto ndim = lu.ndim(); @@ -26,57 +33,72 @@ void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) { lu.set_data( allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags); copy_inplace( - a, lu, a.shape(), 
a.strides(), strides, 0, 0, CopyType::GeneralGeneral); + a, + lu, + a.shape(), + a.strides(), + strides, + 0, + 0, + CopyType::GeneralGeneral, + stream); auto a_ptr = lu.data(); - pivots.set_data(allocator::malloc_or_wait(pivots.nbytes())); row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes())); auto pivots_ptr = pivots.data(); auto row_indices_ptr = row_indices.data(); - - int info; size_t num_matrices = a.size() / (M * N); - for (size_t i = 0; i < num_matrices; ++i) { - // Compute LU factorization of A - getrf( - /* m */ &M, - /* n */ &N, - /* a */ a_ptr, - /* lda */ &M, - /* ipiv */ reinterpret_cast(pivots_ptr), - /* info */ &info); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(lu); + encoder.set_output_array(pivots); + encoder.set_output_array(row_indices); - if (info != 0) { - std::stringstream ss; - ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info - << ((info > 0) ? " because matrix is singular" - : " because argument had an illegal value"); - throw std::runtime_error(ss.str()); - } + encoder.dispatch( + [a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable { + int info; + for (size_t i = 0; i < num_matrices; ++i) { + // Compute LU factorization of A + getrf( + /* m */ &M, + /* n */ &N, + /* a */ a_ptr, + /* lda */ &M, + /* ipiv */ reinterpret_cast(pivots_ptr), + /* info */ &info); - // Subtract 1 to get 0-based index - int j = 0; - for (; j < pivots.shape(-1); ++j) { - pivots_ptr[j]--; - row_indices_ptr[j] = j; - } - for (; j < row_indices.shape(-1); ++j) { - row_indices_ptr[j] = j; - } - for (int j = pivots.shape(-1) - 1; j >= 0; --j) { - auto piv = pivots_ptr[j]; - auto t1 = row_indices_ptr[piv]; - auto t2 = row_indices_ptr[j]; - row_indices_ptr[j] = t1; - row_indices_ptr[piv] = t2; - } + if (info != 0) { + std::stringstream ss; + ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info + << ((info > 0) ? 
" because matrix is singular" + : " because argument had an illegal value"); + throw std::runtime_error(ss.str()); + } - // Advance pointers to the next matrix - a_ptr += M * N; - pivots_ptr += pivots.shape(-1); - row_indices_ptr += pivots.shape(-1); - } + // Subtract 1 to get 0-based index + int j = 0; + for (; j < K; ++j) { + pivots_ptr[j]--; + row_indices_ptr[j] = j; + } + for (; j < M; ++j) { + row_indices_ptr[j] = j; + } + for (int j = K - 1; j >= 0; --j) { + auto piv = pivots_ptr[j]; + auto t1 = row_indices_ptr[piv]; + auto t2 = row_indices_ptr[j]; + row_indices_ptr[j] = t1; + row_indices_ptr[piv] = t2; + } + + // Advance pointers to the next matrix + a_ptr += M * N; + pivots_ptr += K; + row_indices_ptr += M; + } + }); } void LUF::eval_cpu( @@ -85,10 +107,10 @@ void LUF::eval_cpu( assert(inputs.size() == 1); switch (inputs[0].dtype()) { case float32: - luf_impl(inputs[0], outputs[0], outputs[1], outputs[2]); + luf_impl(inputs[0], outputs[0], outputs[1], outputs[2], stream()); break; case float64: - luf_impl(inputs[0], outputs[0], outputs[1], outputs[2]); + luf_impl(inputs[0], outputs[0], outputs[1], outputs[2], stream()); break; default: throw std::runtime_error( diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 5c9753e88..75b8a5b3d 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -5,6 +5,7 @@ #include "mlx/array.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" @@ -64,36 +65,36 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { auto& b_pre = inputs[1]; auto check_transpose = - [](const array& arr, bool do_copy, bool expand_all = false) { + [s = stream()](const array& arr, bool do_copy, bool expand_all = false) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (!expand_all && stx == arr.shape(-1) && sty == 1) { if (do_copy) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::Vector); - return std::make_tuple(false, stx, arr_copy); + copy(arr, arr_copy, CopyType::Vector, s); + return std::make_tuple(false, stx, arr_copy, true); } - return std::make_tuple(false, stx, arr); + return std::make_tuple(false, stx, arr, false); } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) { if (do_copy) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::Vector); - return std::make_tuple(true, sty, arr_copy); + copy(arr, arr_copy, CopyType::Vector, s); + return std::make_tuple(true, sty, arr_copy, true); } - return std::make_tuple(true, sty, arr); + return std::make_tuple(true, sty, arr, false); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::General); + copy(arr, arr_copy, CopyType::General, s); int64_t stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); + return std::make_tuple(false, stx, arr_copy, true); } }; bool has_op_mask = inputs.size() > 3; bool has_out_mask = inputs.size() == 3 || inputs.size() == 5; - auto [a_transposed, lda, a] = + auto [a_transposed, lda, a, a_copied] = check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_); - auto [b_transposed, ldb, b] = + auto [b_transposed, ldb, b, b_copied] = check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_); size_t M = a.shape(-2); @@ -104,31 +105,39 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { 
return; } + auto& encoder = cpu::get_command_encoder(stream()); if (K == 0) { - std::memset(static_cast(out.data()), 0, out.nbytes()); + encoder.set_output_array(out); + encoder.dispatch([out_ptr = out.data(), nbytes = out.nbytes()]() { + std::memset(out_ptr, 0, nbytes); + }); return; } - auto mask_array = [](const array& mask, + auto mask_array = [](const void* mask, float* data, int block_size, int batch_idx, int X, int Y, size_t X_data_str, - size_t Y_data_str) { + size_t Y_data_str, + const Shape& mask_shape, + const Strides& mask_strides, + bool is_bool) { + auto ndim = mask_shape.size(); auto mask_offset = elem_to_loc( - mask.shape(-1) * mask.shape(-2) * batch_idx, - mask.shape(), - mask.strides()); + mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx, + mask_shape, + mask_strides); - auto X_mask_str = mask.strides()[mask.ndim() - 2]; - auto Y_mask_str = mask.strides()[mask.ndim() - 1]; + auto X_mask_str = mask_strides[ndim - 2]; + auto Y_mask_str = mask_strides[ndim - 1]; - if (mask.dtype() == bool_) { + if (is_bool) { return mask_matrix( data, - mask.data(), + static_cast(mask), block_size, X, Y, @@ -140,7 +149,7 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { } else { return mask_matrix( data, - mask.data(), + static_cast(mask), block_size, X, Y, @@ -152,61 +161,155 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { } }; - for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) { - // Adjust pointer - float* ai = - a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()); - float* bi = - b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()); - float* ci = out.data() + M * N * i; + encoder.set_input_array(a); + encoder.set_input_array(b); + const void* a_mask_ptr; + const void* b_mask_ptr; + const void* out_mask_ptr; + Shape a_mask_shape; + Shape b_mask_shape; + Shape out_mask_shape; + Strides a_mask_strides; + Strides b_mask_strides; + Strides out_mask_strides; + bool a_mask_bool; + bool b_mask_bool; + bool out_mask_bool; + if (has_op_mask) { + auto& a_mask = inputs[inputs.size() - 2]; + auto& b_mask = inputs[inputs.size() - 1]; + a_mask_ptr = a_mask.data(); + b_mask_ptr = b_mask.data(); + a_mask_shape = a_mask.shape(); + b_mask_shape = b_mask.shape(); + a_mask_strides = a_mask.strides(); + b_mask_strides = b_mask.strides(); + a_mask_bool = (a_mask.dtype() == bool_); + b_mask_bool = (b_mask.dtype() == bool_); + encoder.set_input_array(a_mask); + encoder.set_input_array(b_mask); + } + if (has_out_mask) { + auto& out_mask = inputs[2]; + out_mask_ptr = out_mask.data(); + out_mask_bool = (out_mask.dtype() == bool_); + encoder.set_input_array(out_mask); + out_mask_shape = out_mask.shape(); + out_mask_strides = out_mask.strides(); + } + encoder.set_output_array(out); + auto a_ptr = a.data(); + auto b_ptr = b.data(); + auto out_ptr = out.data(); + size_t num_matrices = out.size() / (M * size_t(N)); + auto ldc = out.shape(-1); - // Zero out blocks in a and b if needed - if (has_op_mask) { - auto& a_mask = inputs[inputs.size() - 2]; - mask_array( - a_mask, - ai, - block_size_, - i, + encoder.dispatch([a_ptr, + b_ptr, + out_ptr, + a_mask_ptr, + b_mask_ptr, + out_mask_ptr, + has_op_mask, + has_out_mask, + block_size = block_size_, + num_matrices, + M, + N, + K, + a_transposed = a_transposed, + b_transposed = b_transposed, + lda = lda, + ldb = ldb, + ldc, + a_shape = a.shape(), + a_strides = a.strides(), + b_shape = b.shape(), + b_strides = b.strides(), + a_mask_shape = std::move(a_mask_shape), + b_mask_shape = std::move(b_mask_shape), 
+ out_mask_shape = std::move(out_mask_shape), + a_mask_strides = std::move(a_mask_strides), + b_mask_strides = std::move(b_mask_strides), + out_mask_strides = std::move(out_mask_strides), + mask_array, + a_mask_bool, + b_mask_bool, + out_mask_bool]() { + for (int i = 0; i < num_matrices; ++i) { + // Adjust pointer + float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides); + float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides); + float* ci = out_ptr + M * N * i; + + // Zero out blocks in a and b if needed + if (has_op_mask) { + mask_array( + a_mask_ptr, + ai, + block_size, + i, + M, + K, + a_transposed ? 1 : lda, + a_transposed ? lda : 1, + a_mask_shape, + a_mask_strides, + a_mask_bool); + + mask_array( + b_mask_ptr, + bi, + block_size, + i, + K, + N, + b_transposed ? 1 : ldb, + b_transposed ? ldb : 1, + b_mask_shape, + b_mask_strides, + b_mask_bool); + } + + // Do matmul + cblas_sgemm( + CblasRowMajor, + a_transposed ? CblasTrans : CblasNoTrans, // transA + b_transposed ? CblasTrans : CblasNoTrans, // transB M, - K, - a_transposed ? 1 : lda, - a_transposed ? lda : 1); - - auto& b_mask = inputs[inputs.size() - 1]; - mask_array( - b_mask, - bi, - block_size_, - i, - K, N, - b_transposed ? 1 : ldb, - b_transposed ? ldb : 1); - } + K, + 1.0, // alpha + ai, + lda, + bi, + ldb, + 0.0, // beta + ci, + ldc); - // Do matmul - cblas_sgemm( - CblasRowMajor, - a_transposed ? CblasTrans : CblasNoTrans, // transA - b_transposed ? CblasTrans : CblasNoTrans, // transB - M, - N, - K, - 1.0, // alpha - ai, - lda, - bi, - ldb, - 0.0, // beta - ci, - out.shape(-1) // ldc - ); - - // Zero out blocks in out - if (has_out_mask) { - mask_array(inputs[2], ci, block_size_, i, M, N, N, 1); + // Zero out blocks in out + if (has_out_mask) { + mask_array( + out_mask_ptr, + ci, + block_size, + i, + M, + N, + N, + 1, + out_mask_shape, + out_mask_strides, + out_mask_bool); + } } + }); + if (a_copied) { + encoder.add_temporary(a); + } + if (b_copied) { + encoder.add_temporary(b); } } @@ -220,7 +323,8 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; - auto check_transpose = [](const array& arr) { + std::vector temps; + auto check_transpose = [s = stream(), &temps](const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (stx == arr.shape(-1) && sty == 1) { @@ -228,10 +332,10 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { } else if (stx == 1 && sty == arr.shape(-2)) { return std::make_tuple(true, sty, arr); } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::General); + temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); + copy(arr, temps.back(), CopyType::General, s); int64_t stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); + return std::make_tuple(false, stx, temps.back()); } }; @@ -246,8 +350,12 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { return; } + auto& encoder = cpu::get_command_encoder(stream()); if (K == 0) { - std::memset(static_cast(out.data()), 0, out.nbytes()); + encoder.set_output_array(out); + encoder.dispatch([out_ptr = out.data(), size = out.size()]() { + std::fill_n(out_ptr, size, 0); + }); return; } @@ -272,29 +380,61 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { const uint32_t* lhs_indices_ptr = lhs_indices.data(); const uint32_t* rhs_indices_ptr = rhs_indices.data(); + encoder.set_input_array(a); + encoder.set_input_array(b); + 
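// Illustrative sketch (not from the patch): the gathered-GEMM pattern being
// encoded in GatherMM::eval_cpu, reduced to its core. Output matrix i
// multiplies A[lhs_idx[i]] with B[rhs_idx[i]]; with contiguous row-major
// batches, plain offsets replace the elem_to_loc calls used above for
// strided layouts. All names are local to this sketch; it assumes a CBLAS
// header is available.
#include <cblas.h>
#include <cstddef>
#include <cstdint>

void gather_mm(const float* a, const float* b, float* out,
               const std::uint32_t* lhs_idx, const std::uint32_t* rhs_idx,
               int batch, int M, int N, int K) {
  for (int i = 0; i < batch; ++i) {
    const float* ai = a + static_cast<std::size_t>(lhs_idx[i]) * M * K;
    const float* bi = b + static_cast<std::size_t>(rhs_idx[i]) * K * N;
    float* ci = out + static_cast<std::size_t>(i) * M * N;
    // Row-major, untransposed: lda = K, ldb = N, ldc = N.
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                M, N, K, 1.0f, ai, K, bi, N, 0.0f, ci, N);
  }
}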
encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + encoder.set_output_array(out); + auto ldc = out.shape(-1); - for (int i = 0; i < batch_size_out; i++) { - // Get index - uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)]; - uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)]; + encoder.dispatch([a_ptr = a.data(), + b_ptr = b.data(), + out_ptr = out.data(), + M, + N, + K, + lda = lda, + ldb = ldb, + a_transposed = a_transposed, + b_transposed = b_transposed, + ldc, + lhs_indices_ptr, + rhs_indices_ptr, + lhs_indices_shape = lhs_indices.shape(), + lhs_indices_strides = lhs_indices.strides(), + rhs_indices_shape = rhs_indices.shape(), + rhs_indices_strides = rhs_indices.strides(), + batch_size_out, + matrix_stride_out, + batch_shape_A = std::move(batch_shape_A), + batch_shape_B = std::move(batch_shape_B), + batch_strides_A = std::move(batch_strides_A), + batch_strides_B = std::move(batch_strides_B)]() { + for (int i = 0; i < batch_size_out; i++) { + // Get index + uint32_t indx_A = lhs_indices_ptr[elem_to_loc( + i, lhs_indices_shape, lhs_indices_strides)]; + uint32_t indx_B = rhs_indices_ptr[elem_to_loc( + i, rhs_indices_shape, rhs_indices_strides)]; - cblas_sgemm( - CblasRowMajor, - a_transposed ? CblasTrans : CblasNoTrans, // transA - b_transposed ? CblasTrans : CblasNoTrans, // transB - M, - N, - K, - 1.0f, // alpha - a.data() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A), - lda, - b.data() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B), - ldb, - 0.0f, // beta - out.data() + matrix_stride_out * i, - out.shape(-1) // ldc - ); - } + cblas_sgemm( + CblasRowMajor, + a_transposed ? CblasTrans : CblasNoTrans, // transA + b_transposed ? CblasTrans : CblasNoTrans, // transB + M, + N, + K, + 1.0f, // alpha + a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A), + lda, + b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B), + ldb, + 0.0f, // beta + out_ptr + matrix_stride_out * i, + ldc); + } + }); + encoder.add_temporaries(std::move(temps)); } } // namespace mlx::core diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index 0712ea2d3..082531894 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -3,18 +3,76 @@ #include #include "mlx/array.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/gemm.h" #include "mlx/primitives.h" namespace mlx::core { +template +void matmul_dispatch( + const array& a, + const array& b, + array& out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + float alpha, + float beta, + Stream stream) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + T* out_ptr = out.data(); + size_t ldc = out.shape(-1); + size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1)); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a_ptr, + b_ptr, + out_ptr, + a_transposed, + b_transposed, + lda, + ldb, + ldc, + alpha, + beta, + batch_size, + a_shape = a.shape(), + a_strides = a.strides(), + b_shape = b.shape(), + b_strides = b.strides()]() { + matmul( + a_ptr, + b_ptr, + out_ptr, + a_transposed, + b_transposed, + lda, + ldb, + ldc, + alpha, + beta, + batch_size, + a_shape, + a_strides, + b_shape, + b_strides); + }); +} + void matmul_general( const array& a_pre, const array& b_pre, array& out, + Stream stream, float alpha = 1.0f, float beta = 0.0f) { - auto check_transpose = 
[](const array& arr) { + std::vector temps; + auto check_transpose = [stream, &temps](const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (stx == arr.shape(-1) && sty == 1) { @@ -22,10 +80,10 @@ void matmul_general( } else if (stx == 1 && sty == arr.shape(-2)) { return std::make_tuple(true, sty, arr); } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::General); + temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); + copy(arr, temps.back(), CopyType::General, stream); stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); + return std::make_tuple(false, stx, temps.back()); } }; @@ -39,28 +97,34 @@ void matmul_general( } if (out.dtype() == float32) { - matmul(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); + matmul_dispatch( + a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); } else if (out.dtype() == float16) { - matmul( - a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); + matmul_dispatch( + a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); } else if (out.dtype() == bfloat16) { - matmul( - a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); + matmul_dispatch( + a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); } else if (out.dtype() == float64) { - matmul( - a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); + matmul_dispatch( + a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); } else { throw std::runtime_error("[Matmul::eval_cpu] Invalid type."); } + cpu::get_command_encoder(stream).add_temporaries(std::move(temps)); } void Matmul::eval_cpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc_or_wait(out.nbytes())); if (inputs[0].shape(-1) == 0) { - std::memset(out.data(), 0, out.nbytes()); + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_output_array(out); + encoder.dispatch([out_ptr = out.data(), nbytes = out.nbytes()]() { + std::memset(out_ptr, 0, nbytes); + }); return; } - return matmul_general(inputs[0], inputs[1], out); + matmul_general(inputs[0], inputs[1], out, stream()); } void AddMM::eval_cpu(const std::vector& inputs, array& out) { @@ -74,9 +138,9 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : (c.flags().row_contiguous ? 
CopyType::Vector : CopyType::General); - copy(c, out, ctype); + copy(c, out, ctype, stream()); - return matmul_general(inputs[0], inputs[1], out, alpha_, beta_); + matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); } } // namespace mlx::core diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index b5d9d7ef3..a561f1de8 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -7,11 +7,11 @@ #include #include "mlx/allocator.h" -#include "mlx/backend/common/load.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/arange.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/threefry.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -22,39 +22,58 @@ void reshape(const array& in, array& out) { auto [copy_necessary, out_strides] = prepare_reshape(in, out); if (copy_necessary) { out.set_data(allocator::malloc_or_wait(out.nbytes())); - copy_inplace(in, out, CopyType::General); + copy_inplace(in, out, CopyType::General, out.primitive().stream()); } else { shared_buffer_reshape(in, out_strides, out); } } -int64_t compute_dynamic_offset( +static std::pair compute_dynamic_offset( const array& indices, const Strides& strides, - const std::vector& axes) { - auto compute_offset = [&strides, &axes](const auto* indices) { - int64_t offset = 0; - for (int i = 0; i < axes.size(); ++i) { - offset += indices[i] * strides[axes[i]]; - } - return offset; - }; + const std::vector& axes, + Stream stream) { + array offset({1}, int64, nullptr, {}); + bool donate = indices.is_donatable() && + (indices.data_size() * indices.itemsize()) >= offset.itemsize(); + if (donate) { + offset.copy_shared_buffer(indices); + } else { + offset.set_data(allocator::malloc_or_wait(offset.itemsize())); + } + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(indices); + encoder.set_output_array(offset); + auto compute_offset = + [strides, axes, offset = offset.data()](const auto* indices) { + int64_t offset_ = 0; + for (int i = 0; i < axes.size(); ++i) { + offset_ += indices[i] * strides[axes[i]]; + } + offset[0] = offset_; + }; switch (indices.dtype()) { case int8: case uint8: - return compute_offset(indices.data()); + encoder.dispatch(compute_offset, indices.data()); + break; case int16: case uint16: - return compute_offset(indices.data()); + encoder.dispatch(compute_offset, indices.data()); + break; case int32: case uint32: - return compute_offset(indices.data()); + encoder.dispatch(compute_offset, indices.data()); + break; case int64: case uint64: - return compute_offset(indices.data()); + encoder.dispatch(compute_offset, indices.data()); + break; default: throw std::runtime_error("Invalid indices type."); } + return {offset, donate}; } void AsStrided::eval_cpu(const std::vector& inputs, array& out) { @@ -104,14 +123,59 @@ void Transpose::eval_cpu(const std::vector& inputs, array& out) { } void Arange::eval_cpu(const std::vector& inputs, array& out) { - arange(inputs, out, start_, step_); + assert(inputs.size() == 0); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + switch (out.dtype()) { + case bool_: + throw std::runtime_error("Bool type unsupported for arange."); + break; + case uint8: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case uint16: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case uint32: + arange(start_, start_ + step_, out, out.size(), stream()); + 
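// Note the calling convention in these cases: the second argument is
// start_ + step_, i.e. the *second* element of the sequence, from which the
// kernel can recover the step. A minimal sketch of such a kernel, under that
// assumption (fill_arange is a hypothetical name; the real kernel sits
// behind arange() in mlx/backend/cpu/arange.h):
#include <cstddef>

template <typename T>
void fill_arange(double start, double next, T* out, std::size_t size) {
  double step = next - start; // step recovered from the first gap
  double v = start;
  for (std::size_t i = 0; i < size; ++i, v += step) {
    out[i] = static_cast<T>(v);
  }
}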
break; + case uint64: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case int8: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case int16: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case int32: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case int64: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case float16: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case float32: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case float64: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case bfloat16: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + case complex64: + arange(start_, start_ + step_, out, out.size(), stream()); + break; + } } void AsType::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; - copy(in, out, ctype); + copy(in, out, ctype, stream()); } void Concatenate::eval_cpu(const std::vector& inputs, array& out) { @@ -134,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector& inputs, array& out) { size_t data_offset = strides[axis_] * sizes[i]; out_slice.copy_shared_buffer( out, strides, flags, out_slice.size(), data_offset); - copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral); + copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream()); } } @@ -145,7 +209,7 @@ void Contiguous::eval_cpu(const std::vector& inputs, array& out) { (allow_col_major_ && in.flags().col_contiguous)) { out.copy_shared_buffer(in); } else { - copy(in, out, CopyType::General); + copy(in, out, CopyType::General, stream()); } } @@ -169,14 +233,7 @@ void Full::eval_cpu(const std::vector& inputs, array& out) { } else { ctype = CopyType::General; } - copy(in, out, ctype); -} - -void Load::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 0); - out.set_data(allocator::malloc_or_wait(out.nbytes())); - - load(out, offset_, reader_, swap_endianness_); + copy(in, out, ctype, stream()); } void Pad::eval_cpu(const std::vector& inputs, array& out) { @@ -192,7 +249,7 @@ void Pad::eval_cpu(const std::vector& inputs, array& out) { assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); // Fill output with val - copy(val, out, CopyType::Scalar); + copy(val, out, CopyType::Scalar, stream()); // Find offset for start of input values size_t data_offset = 0; @@ -207,7 +264,7 @@ void Pad::eval_cpu(const std::vector& inputs, array& out) { out, out.strides(), out.flags(), out_slice.size(), data_offset); // Copy input values into the slice - copy_inplace(in, out_slice, CopyType::GeneralGeneral); + copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream()); } void RandomBits::eval_cpu(const std::vector& inputs, array& out) { @@ -223,39 +280,49 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { auto kptr = inputs[0].data(); auto cptr = out.data(); - size_t out_skip = (bytes_per_key + 4 - 1) / 4; - auto half_size = out_skip / 2; - bool even = out_skip % 2 == 0; - for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) { - auto ptr = reinterpret_cast(cptr); - // Get ith key - auto kidx = 2 * i; - auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides()); - auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides()); - auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]); + auto& encoder = 
cpu::get_command_encoder(stream()); + encoder.set_input_array(inputs[0]); + encoder.set_output_array(out); + encoder.dispatch([kptr, + cptr, + bytes_per_key, + num_keys, + kshape = keys.shape(), + kstrides = keys.strides()]() mutable { + size_t out_skip = (bytes_per_key + 4 - 1) / 4; + auto half_size = out_skip / 2; + bool even = out_skip % 2 == 0; + for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) { + auto ptr = reinterpret_cast(cptr); + // Get ith key + auto kidx = 2 * i; + auto k1_elem = elem_to_loc(kidx, kshape, kstrides); + auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides); + auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]); - std::pair count{0, half_size + !even}; - for (; count.first + 1 < half_size; count.first++, count.second++) { - std::tie(ptr[count.first], ptr[count.second]) = - random::threefry2x32_hash(key, count); - } - if (count.first < half_size) { - auto rb = random::threefry2x32_hash(key, count); - ptr[count.first++] = rb.first; - if (bytes_per_key % 4 > 0) { - std::copy( - reinterpret_cast(&rb.second), - reinterpret_cast(&rb.second) + bytes_per_key % 4, - cptr + 4 * count.second); - } else { - ptr[count.second] = rb.second; + std::pair count{0, half_size + !even}; + for (; count.first + 1 < half_size; count.first++, count.second++) { + std::tie(ptr[count.first], ptr[count.second]) = + random::threefry2x32_hash(key, count); + } + if (count.first < half_size) { + auto rb = random::threefry2x32_hash(key, count); + ptr[count.first++] = rb.first; + if (bytes_per_key % 4 > 0) { + std::copy( + reinterpret_cast(&rb.second), + reinterpret_cast(&rb.second) + bytes_per_key % 4, + cptr + 4 * count.second); + } else { + ptr[count.second] = rb.second; + } + } + if (!even) { + count.second = 0; + ptr[half_size] = random::threefry2x32_hash(key, count).first; } } - if (!even) { - count.second = 0; - ptr[half_size] = random::threefry2x32_hash(key, count).first; - } - } + }); } void Reshape::eval_cpu(const std::vector& inputs, array& out) { @@ -269,16 +336,23 @@ void DynamicSlice::eval_cpu(const std::vector& inputs, array& out) { } auto& in = inputs[0]; out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_); + auto [in_offset, donated] = + compute_dynamic_offset(inputs[1], in.strides(), axes_, stream()); copy_inplace( /* const array& src = */ in, /* array& dst = */ out, /* const Shape& data_shape = */ out.shape(), /* const Strides& i_strides = */ in.strides(), /* const Strides& o_strides = */ out.strides(), - /* int64_t i_offset = */ i_offset, + /* int64_t i_offset = */ 0, /* int64_t o_offset = */ 0, - /* CopyType ctype = */ CopyType::GeneralGeneral); + /* CopyType ctype = */ CopyType::GeneralGeneral, + stream(), + /* const std::optional& dynamic_i_offset = */ in_offset, + /* const std::optional& dynamic_o_offset = */ std::nullopt); + if (!donated) { + cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset)); + } } void DynamicSliceUpdate::eval_cpu( @@ -296,9 +370,10 @@ void DynamicSliceUpdate::eval_cpu( auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; - copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); + copy(in, out, in.data_size() == 1 ? 
CopyType::Scalar : ctype, stream()); - auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_); + auto [out_offset, donated] = + compute_dynamic_offset(inputs[2], out.strides(), axes_, stream()); copy_inplace( /* const array& src = */ upd, /* array& dst = */ out, @@ -306,8 +381,14 @@ void DynamicSliceUpdate::eval_cpu( /* const std::vector& i_strides = */ upd.strides(), /* const std::vector& o_strides = */ out.strides(), /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ o_offset, - /* CopyType ctype = */ CopyType::GeneralGeneral); + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral, + stream(), + /* const std::optional& dynamic_i_offset = */ std::nullopt, + /* const std::optional& dynamic_o_offset = */ out_offset); + if (!donated) { + cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset)); + } } void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { @@ -329,7 +410,7 @@ void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; - copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); + copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); // Calculate out strides, initial offset and if copy needs to be made auto [data_offset, out_strides] = @@ -344,7 +425,8 @@ void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { /* const std::vector& o_strides = */ out_strides, /* int64_t i_offset = */ 0, /* int64_t o_offset = */ data_offset, - /* CopyType ctype = */ CopyType::GeneralGeneral); + /* CopyType ctype = */ CopyType::GeneralGeneral, + stream()); } void View::eval_cpu(const std::vector& inputs, array& out) { @@ -372,9 +454,9 @@ void View::eval_cpu(const std::vector& inputs, array& out) { if (in.dtype() == bool_) { auto in_tmp = array(in.shape(), uint8, nullptr, {}); in_tmp.copy_shared_buffer(in); - copy_inplace(in_tmp, tmp, CopyType::General); + copy_inplace(in_tmp, tmp, CopyType::General, stream()); } else { - copy_inplace(in, tmp, CopyType::General); + copy_inplace(in, tmp, CopyType::General, stream()); } auto flags = out.flags(); @@ -382,7 +464,7 @@ void View::eval_cpu(const std::vector& inputs, array& out) { flags.row_contiguous = true; auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; - out.move_shared_buffer(tmp, out.strides(), flags, out.size()); + out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); } } diff --git a/mlx/backend/cpu/qrf.cpp b/mlx/backend/cpu/qrf.cpp index 537b63358..8c6f94140 100644 --- a/mlx/backend/cpu/qrf.cpp +++ b/mlx/backend/cpu/qrf.cpp @@ -2,20 +2,18 @@ #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { template -void qrf_impl(const array& a, array& q, array& r) { +void qrf_impl(const array& a, array& q, array& r, Stream stream) { const int M = a.shape(-2); const int N = a.shape(-1); const int lda = M; size_t num_matrices = a.size() / (M * N); - int num_reflectors = std::min(M, N); - auto tau = - allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors); // Copy A to inplace input and make it col-contiguous array in(a.shape(), a.dtype(), nullptr, {}); @@ -29,93 +27,107 @@ void qrf_impl(const array& a, array& q, array& r) { strides[in.ndim() - 1] = M; in.set_data( 
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags); - copy_inplace(a, in, CopyType::GeneralGeneral); - - T optimal_work; - int lwork = -1; - int info; - - // Compute workspace size - geqrf(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info); - - // Update workspace size - lwork = optimal_work; - auto work = allocator::malloc_or_wait(sizeof(T) * lwork); - - // Loop over matrices - for (int i = 0; i < num_matrices; ++i) { - // Solve - geqrf( - &M, - &N, - in.data() + M * N * i, - &lda, - static_cast(tau.raw_ptr()) + num_reflectors * i, - static_cast(work.raw_ptr()), - &lwork, - &info); - } - allocator::free(work); - + copy_inplace(a, in, CopyType::GeneralGeneral, stream); + auto& encoder = cpu::get_command_encoder(stream); + q.set_data(allocator::malloc_or_wait(q.nbytes())); r.set_data(allocator::malloc_or_wait(r.nbytes())); - for (int i = 0; i < num_matrices; ++i) { - /// num_reflectors x N - for (int j = 0; j < r.shape(-2); ++j) { - for (int k = 0; k < j; ++k) { - r.data()[i * N * num_reflectors + j * N + k] = 0; - } - for (int k = j; k < r.shape(-1); ++k) { - r.data()[i * N * num_reflectors + j * N + k] = - in.data()[i * N * M + j + k * M]; + auto in_ptr = in.data(); + auto r_ptr = r.data(); + auto q_ptr = q.data(); + + encoder.set_input_array(in); + encoder.set_output_array(q); + encoder.set_output_array(r); + encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() { + int num_reflectors = std::min(M, N); + auto tau = + allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors); + + T optimal_work; + int lwork = -1; + int info; + + // Compute workspace size + geqrf(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info); + + // Update workspace size + lwork = optimal_work; + auto work = allocator::malloc_or_wait(sizeof(T) * lwork); + + // Loop over matrices + for (int i = 0; i < num_matrices; ++i) { + // Solve + geqrf( + &M, + &N, + in_ptr + M * N * i, + &lda, + static_cast(tau.raw_ptr()) + num_reflectors * i, + static_cast(work.raw_ptr()), + &lwork, + &info); + } + allocator::free(work); + + for (int i = 0; i < num_matrices; ++i) { + /// num_reflectors x N + for (int j = 0; j < num_reflectors; ++j) { + for (int k = 0; k < j; ++k) { + r_ptr[i * N * num_reflectors + j * N + k] = 0; + } + for (int k = j; k < N; ++k) { + r_ptr[i * N * num_reflectors + j * N + k] = + in_ptr[i * N * M + j + k * M]; + } } } - } - // Get work size - lwork = -1; - orgqr( - &M, - &num_reflectors, - &num_reflectors, - nullptr, - &lda, - nullptr, - &optimal_work, - &lwork, - &info); - lwork = optimal_work; - work = allocator::malloc_or_wait(sizeof(T) * lwork); - - // Loop over matrices - for (int i = 0; i < num_matrices; ++i) { - // Compute Q + // Get work size + lwork = -1; orgqr( &M, &num_reflectors, &num_reflectors, - in.data() + M * N * i, + nullptr, &lda, - static_cast(tau.raw_ptr()) + num_reflectors * i, - static_cast(work.raw_ptr()), + nullptr, + &optimal_work, &lwork, &info); - } + lwork = optimal_work; + work = allocator::malloc_or_wait(sizeof(T) * lwork); - q.set_data(allocator::malloc_or_wait(q.nbytes())); - for (int i = 0; i < num_matrices; ++i) { - // M x num_reflectors - for (int j = 0; j < q.shape(-2); ++j) { - for (int k = 0; k < q.shape(-1); ++k) { - q.data()[i * M * num_reflectors + j * num_reflectors + k] = - in.data()[i * N * M + j + k * M]; + // Loop over matrices + for (int i = 0; i < num_matrices; ++i) { + // Compute Q + orgqr( + &M, + &num_reflectors, + &num_reflectors, + in_ptr + M * N * i, + &lda, + static_cast(tau.raw_ptr()) + 
num_reflectors * i, + static_cast(work.raw_ptr()), + &lwork, + &info); + } + + for (int i = 0; i < num_matrices; ++i) { + // M x num_reflectors + for (int j = 0; j < M; ++j) { + for (int k = 0; k < num_reflectors; ++k) { + q_ptr[i * M * num_reflectors + j * num_reflectors + k] = + in_ptr[i * N * M + j + k * M]; + } } } - } - // Cleanup - allocator::free(work); - allocator::free(tau); + // Cleanup + allocator::free(work); + allocator::free(tau); + }); + encoder.add_temporary(in); } void QRF::eval_cpu( @@ -123,10 +135,10 @@ void QRF::eval_cpu( std::vector& outputs) { switch (inputs[0].dtype()) { case float32: - qrf_impl(inputs[0], outputs[0], outputs[1]); + qrf_impl(inputs[0], outputs[0], outputs[1], stream()); break; case float64: - qrf_impl(inputs[0], outputs[0], outputs[1]); + qrf_impl(inputs[0], outputs[0], outputs[1], stream()); break; default: throw std::runtime_error( diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index da7788971..526f0b716 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -3,6 +3,7 @@ #include #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" @@ -316,6 +317,76 @@ void _qmm_dispatch_typed( } } +template +void _qmm_dispatch_typed( + array& out, + const array& x, + const array& w, + const array& scales, + const array& biases, + int bits, + int group_size, + bool transposed_w, + Stream stream) { + int K = x.shape(-1); + int M = x.ndim() > 1 ? x.shape(-2) : 1; + int N = out.shape(-1); + int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; + int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; + int batch_size = x.size() / (K * M); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_output_array(out); + + auto out_ptr = out.data(); + auto x_ptr = x.data(); + auto w_ptr = w.data(); + auto scales_ptr = scales.data(); + auto biases_ptr = biases.data(); + + encoder.dispatch([out_ptr, + x_ptr, + w_ptr, + scales_ptr, + biases_ptr, + x_shape = x.shape(), + x_strides = x.strides(), + w_shape = w.shape(), + w_strides = w.strides(), + scales_shape = scales.shape(), + scales_strides = scales.strides(), + biases_shape = biases.shape(), + biases_strides = biases.strides(), + w_els, + g_els, + batch_size, + M, + N, + K, + bits, + group_size, + transposed_w] { + for (int i = 0; i < batch_size; i++) { + _qmm_dispatch_typed( + out_ptr + i * M * N, + x_ptr + elem_to_loc(i * M * K, x_shape, x_strides), + w_ptr + elem_to_loc(i * w_els, w_shape, w_strides), + scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides), + biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides), + M, + N, + K, + bits, + group_size, + transposed_w); + } + }); +} + void _qmm_dispatch( array& out, const array& x, @@ -324,64 +395,111 @@ void _qmm_dispatch( const array& biases, int bits, int group_size, - bool transposed_w) { + bool transposed_w, + Stream stream) { + switch (x.dtype()) { + case float32: + _qmm_dispatch_typed( + out, x, w, scales, biases, bits, group_size, transposed_w, stream); + break; + case float16: + _qmm_dispatch_typed( + out, x, w, scales, biases, bits, group_size, transposed_w, stream); + break; + case bfloat16: + _qmm_dispatch_typed( + out, x, w, scales, biases, bits, group_size, transposed_w, stream); + break; + default: + throw 
std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } +} + +template +void _bs_qmm_dispatch_typed( + array& out, + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + int bits, + int group_size, + bool transposed_w, + Stream stream) { int K = x.shape(-1); - int M = x.ndim() > 1 ? x.shape(-2) : 1; + int M = x.shape(-2); int N = out.shape(-1); - int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; - int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; + int w_els = w.shape(-1) * w.shape(-2); + int g_els = scales.shape(-1) * scales.shape(-2); - int batch_size = x.size() / (K * M); - for (int i = 0; i < batch_size; i++) { - switch (x.dtype()) { - case float32: - _qmm_dispatch_typed( - out.data() + i * M * N, - x.data() + elem_to_loc(i * M * K, x), - w.data() + elem_to_loc(i * w_els, w), - scales.data() + elem_to_loc(i * g_els, scales), - biases.data() + elem_to_loc(i * g_els, biases), - M, - N, - K, - bits, - group_size, - transposed_w); - break; - case float16: - _qmm_dispatch_typed( - out.data() + i * M * N, - x.data() + elem_to_loc(i * M * K, x), - w.data() + elem_to_loc(i * w_els, w), - scales.data() + elem_to_loc(i * g_els, scales), - biases.data() + elem_to_loc(i * g_els, biases), - M, - N, - K, - bits, - group_size, - transposed_w); - break; - case bfloat16: - _qmm_dispatch_typed( - out.data() + i * M * N, - x.data() + elem_to_loc(i * M * K, x), - w.data() + elem_to_loc(i * w_els, w), - scales.data() + elem_to_loc(i * g_els, scales), - biases.data() + elem_to_loc(i * g_els, biases), - M, - N, - K, - bits, - group_size, - transposed_w); - break; - default: - throw std::invalid_argument( - "[quantized_matmul] only floating types are supported"); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + encoder.set_output_array(out); + + auto out_ptr = out.data(); + auto x_ptr = x.data(); + auto w_ptr = w.data(); + auto scales_ptr = scales.data(); + auto biases_ptr = biases.data(); + auto lhs_indices_ptr = lhs_indices.data(); + auto rhs_indices_ptr = rhs_indices.data(); + + encoder.dispatch([out_ptr, + x_ptr, + w_ptr, + scales_ptr, + biases_ptr, + lhs_indices_ptr, + rhs_indices_ptr, + x_shape = x.shape(), + x_strides = x.strides(), + w_shape = w.shape(), + w_strides = w.strides(), + scales_shape = scales.shape(), + scales_strides = scales.strides(), + biases_shape = biases.shape(), + biases_strides = biases.strides(), + lhs_indices_shape = lhs_indices.shape(), + lhs_indices_strides = lhs_indices.strides(), + rhs_indices_shape = rhs_indices.shape(), + rhs_indices_strides = rhs_indices.strides(), + w_els, + g_els, + indices_size = lhs_indices.size(), + M, + N, + K, + bits, + group_size, + transposed_w]() { + for (int i = 0; i < indices_size; i++) { + int x_idx = lhs_indices_ptr[elem_to_loc( + i, lhs_indices_shape, lhs_indices_strides)]; + int w_idx = rhs_indices_ptr[elem_to_loc( + i, rhs_indices_shape, rhs_indices_strides)]; + _qmm_dispatch_typed( + out_ptr + i * M * N, + x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides), + w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides), + scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides), + biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides), + M, + N, + K, + bits, + 
             group_size,
+            transposed_w);
   }
-  }
+  });
 }
 
 void _bs_qmm_dispatch(
@@ -394,68 +512,54 @@ void _bs_qmm_dispatch(
     const array& rhs_indices,
     int bits,
     int group_size,
-    bool transposed_w) {
-  int K = x.shape(-1);
-  int M = x.shape(-2);
-  int N = out.shape(-1);
-
-  int w_els = w.shape(-1) * w.shape(-2);
-  int g_els = scales.shape(-1) * scales.shape(-2);
-
-  const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
-  const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
-
-  for (int i = 0; i < lhs_indices.size(); i++) {
-    int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
-    int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
-
-    switch (x.dtype()) {
-      case float32:
-        _qmm_dispatch_typed<float>(
-            out.data<float>() + i * M * N,
-            x.data<float>() + elem_to_loc(x_idx * M * K, x),
-            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
-            scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
-            biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
-            M,
-            N,
-            K,
-            bits,
-            group_size,
-            transposed_w);
-        break;
-      case float16:
-        _qmm_dispatch_typed<float16_t>(
-            out.data<float16_t>() + i * M * N,
-            x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
-            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
-            scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
-            biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
-            M,
-            N,
-            K,
-            bits,
-            group_size,
-            transposed_w);
-        break;
-      case bfloat16:
-        _qmm_dispatch_typed<bfloat16_t>(
-            out.data<bfloat16_t>() + i * M * N,
-            x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
-            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
-            scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
-            biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
-            M,
-            N,
-            K,
-            bits,
-            group_size,
-            transposed_w);
-        break;
-      default:
-        throw std::invalid_argument(
-            "[quantized_matmul] only floating types are supported");
-    }
+    bool transposed_w,
+    Stream stream) {
+  switch (x.dtype()) {
+    case float32:
+      _bs_qmm_dispatch_typed<float>(
+          out,
+          x,
+          w,
+          scales,
+          biases,
+          lhs_indices,
+          rhs_indices,
+          bits,
+          group_size,
+          transposed_w,
+          stream);
+      break;
+    case float16:
+      _bs_qmm_dispatch_typed<float16_t>(
+          out,
+          x,
+          w,
+          scales,
+          biases,
+          lhs_indices,
+          rhs_indices,
+          bits,
+          group_size,
+          transposed_w,
+          stream);
+      break;
+    case bfloat16:
+      _bs_qmm_dispatch_typed<bfloat16_t>(
+          out,
+          x,
+          w,
+          scales,
+          biases,
+          lhs_indices,
+          rhs_indices,
+          bits,
+          group_size,
+          transposed_w,
+          stream);
+      break;
+    default:
+      throw std::invalid_argument(
+          "[quantized_matmul] only floating types are supported");
   }
 }
 
@@ -469,13 +573,14 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& scales_pre = inputs[2];
   auto& biases_pre = inputs[3];
 
-  auto ensure_row_contiguous = [](const array& arr) {
+  std::vector<array> temps;
+  auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
     if (arr.flags().row_contiguous) {
       return arr;
     } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General);
-      return arr_copy;
+      temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
+      copy(arr, temps.back(), CopyType::General, s);
+      return temps.back();
     }
   };
 
@@ -485,7 +590,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto biases = ensure_row_contiguous(biases_pre);
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
+  _qmm_dispatch(
+      out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
+  auto& enc = cpu::get_command_encoder(stream());
+  enc.add_temporaries(std::move(temps));
 }
 
 void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -498,15 +606,17 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& lhs_indices = inputs[4];
   auto& rhs_indices = inputs[5];
 
-  auto ensure_row_contiguous_last_dims = [](const array& arr) {
+  std::vector<array> temps;
+  auto ensure_row_contiguous_last_dims = [s = stream(),
+                                          &temps](const array& arr) {
     auto stride_0 = arr.strides()[arr.ndim() - 2];
     auto stride_1 = arr.strides()[arr.ndim() - 1];
     if (stride_0 == arr.shape(-1) && stride_1 == 1) {
       return arr;
     } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General);
-      return arr_copy;
+      temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
+      copy(arr, temps.back(), CopyType::General, s);
+      return temps.back();
     }
   };
 
@@ -526,31 +636,30 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       rhs_indices,
       group_size_,
       bits_,
-      transpose_);
+      transpose_,
+      stream());
+  auto& enc = cpu::get_command_encoder(stream());
+  enc.add_temporaries(std::move(temps));
 }
 
 template <typename T, typename U>
 void quantize(
-    const array& w_,
-    array& out_,
-    array& scales_,
-    array& biases_,
+    const T* w,
+    U* out,
+    T* scales,
+    T* biases,
     int bits,
-    int group_size) {
-  const T* w = w_.data<T>();
-
-  auto out = out_.data<U>();
-  T* scales = scales_.data<T>();
-  T* biases = biases_.data<T>();
-
+    int group_size,
+    size_t w_size) {
   float n_bins = (1 << bits) - 1;
   float eps = 1e-7;
+
   bool power_of_2_bits = is_power_of_2(bits);
   int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
   // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
   int bytes_per_pack = power_of_2_bits ? 1 : 3;
   int int_per_group = group_size * bytes_per_pack / el_per_int;
-  size_t n_groups = w_.size() / group_size;
+  size_t n_groups = w_size / group_size;
 
   for (size_t i = 0; i < n_groups; ++i) {
     size_t w_idx = i * group_size;
@@ -593,20 +702,50 @@ void quantize(
   }
 }
 
+template <typename T, typename U>
+void dispatch_quantize(
+    const array& w,
+    array& out,
+    array& scales,
+    array& biases,
+    int bits,
+    int group_size,
+    Stream stream) {
+  auto w_ptr = w.data<T>();
+  auto out_ptr = out.data<U>();
+  auto scales_ptr = scales.data<T>();
+  auto biases_ptr = biases.data<T>();
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(w);
+  encoder.set_output_array(out);
+  encoder.set_output_array(scales);
+  encoder.set_output_array(biases);
+  encoder.dispatch([w_ptr,
+                    out_ptr,
+                    scales_ptr,
+                    biases_ptr,
+                    bits,
+                    group_size,
+                    w_size = w.size()]() {
+    quantize<T, U>(
+        w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
+  });
+}
+
 void fast::AffineQuantize::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  auto ensure_row_contiguous = [](const array& arr) {
+  auto ensure_row_contiguous = [s = stream()](const array& arr) {
     if (arr.flags().row_contiguous) {
-      return arr;
+      return std::make_pair(arr, false);
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General);
-      return arr_copy;
+      copy(arr, arr_copy, CopyType::General, s);
+      return std::make_pair(arr_copy, true);
     }
   };
 
-  auto w = ensure_row_contiguous(inputs[0]);
+  auto [w, copied] = ensure_row_contiguous(inputs[0]);
 
   auto& out = outputs[0];
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -616,27 +755,35 @@ void fast::AffineQuantize::eval_cpu(
   biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
   if (w.dtype() == float16) {
     if (is_power_of_2(bits_)) {
-      quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
+      dispatch_quantize<float16_t, uint32_t>(
+          w, out, scales, biases, bits_, group_size_, stream());
     } else {
-      quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
+      dispatch_quantize<float16_t, uint8_t>(
+          w, out, scales, biases, bits_, group_size_, stream());
    }
   } else if (w.dtype() == bfloat16) {
     if (is_power_of_2(bits_)) {
-      quantize<bfloat16_t, uint32_t>(
-          w, out, scales, biases, bits_, group_size_);
+      dispatch_quantize<bfloat16_t, uint32_t>(
+          w, out, scales, biases, bits_, group_size_, stream());
     } else {
-      quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
+      dispatch_quantize<bfloat16_t, uint8_t>(
+          w, out, scales, biases, bits_, group_size_, stream());
     }
   } else if (w.dtype() == float32) {
     if (is_power_of_2(bits_)) {
-      quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
+      dispatch_quantize<float, uint32_t>(
+          w, out, scales, biases, bits_, group_size_, stream());
     } else {
-      quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
+      dispatch_quantize<float, uint8_t>(
+          w, out, scales, biases, bits_, group_size_, stream());
    }
  } else {
    throw std::runtime_error(
        "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
  }
+  if (copied) {
+    cpu::get_command_encoder(stream()).add_temporary(w);
+  }
 }
 
 } // namespace mlx::core
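The same encoder pattern repeats in every file below, so it is worth seeing once in isolation. A minimal sketch (illustrative only; `copy_input_eval_cpu` is a hypothetical function, but `get_command_encoder`, `set_input_array`, `set_output_array`, and `dispatch` are exactly the APIs used throughout this diff):

    // Allocate the output eagerly, then defer the actual work to the stream.
    void copy_input_eval_cpu(const array& in, array& out, Stream stream) {
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
      auto& encoder = cpu::get_command_encoder(stream);
      encoder.set_input_array(in);   // record a read dependency on `in`
      encoder.set_output_array(out); // record that this task produces `out`
      // Capture raw pointers and sizes by value: the lambda runs later, on
      // the stream's thread, after eval_cpu has already returned.
      encoder.dispatch(
          [src = in.data<float>(), dst = out.data<float>(), n = in.size()]() {
            std::copy(src, src + n, dst);
          });
    }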
diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp
index 1c93caaf1..424894cfd 100644
--- a/mlx/backend/cpu/reduce.cpp
+++ b/mlx/backend/cpu/reduce.cpp
@@ -5,6 +5,7 @@
 #include <limits>
 
 #include "mlx/backend/common/reduce.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/primitives.h"
 
@@ -140,25 +141,33 @@ void reduction_op(
     array& out,
     const std::vector<int>& axes,
     U init,
-    Op op) {
+    Stream stream) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
   ReductionPlan plan = get_reduction_plan(x, axes);
 
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+
+  auto in_ptr = x.data<T>();
+  auto out_ptr = out.data<U>();
   if (plan.type == ContiguousAllReduce) {
-    U* out_ptr = out.data<U>();
-    *out_ptr = init;
-    contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
+    encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
+      *out_ptr = init;
+      contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
+    });
     return;
   }
 
   if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
     int reduction_size = plan.shape[0];
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
-    for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
-      *out_ptr = init;
-      contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
-    }
+    encoder.dispatch(
+        [in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
+          for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
+            *out_ptr = init;
+            contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
+          }
+        });
     return;
   }
 
@@ -166,34 +175,43 @@ void reduction_op(
     int reduction_size = plan.shape.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     // Unrolling the following loop (and implementing it in order for
     // ContiguousReduce) should hold extra performance boost.
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    if (plan.shape.size() == 0) {
-      for (int i = 0; i < out.size(); i++, out_ptr++) {
-        int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
-        contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      if (plan.shape.size() == 0) {
+        for (int i = 0; i < size; i++, out_ptr++) {
+          int offset = elem_to_loc(i, shape, strides);
+          *out_ptr = init;
+          contiguous_reduce(
+              in_ptr + offset, out_ptr, reduction_size, Op{}, init);
+        }
+      } else {
+        for (int i = 0; i < size; i++, out_ptr++) {
+          int offset = elem_to_loc(i, shape, strides);
+          *out_ptr = init;
+          nd_loop(
+              [&](int extra_offset) {
+                contiguous_reduce(
+                    in_ptr + offset + extra_offset,
+                    out_ptr,
+                    reduction_size,
+                    Op{},
+                    init);
+              },
+              plan.shape,
+              plan.strides);
+        }
       }
-    } else {
-      for (int i = 0; i < out.size(); i++, out_ptr++) {
-        int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
-        nd_loop(
-            [&](int extra_offset) {
-              contiguous_reduce(
-                  x_ptr + offset + extra_offset,
-                  out_ptr,
-                  reduction_size,
-                  op,
-                  init);
-            },
-            plan.shape,
-            plan.strides);
-      }
-    }
+    });
     return;
   }
 
@@ -202,14 +220,20 @@ void reduction_op(
     size_t reduction_stride = plan.strides.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
-    for (int i = 0; i < out.size(); i += reduction_stride) {
-      std::fill_n(out_ptr, reduction_stride, init);
-      strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
-      x_ptr += reduction_stride * reduction_size;
-      out_ptr += reduction_stride;
-    }
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      reduction_stride,
+                      size = out.size()]() mutable {
+      for (int i = 0; i < size; i += reduction_stride) {
+        std::fill_n(out_ptr, reduction_stride, init);
+        strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
+        in_ptr += reduction_stride * reduction_size;
+        out_ptr += reduction_stride;
+      }
+    });
     return;
   }
 
@@ -219,53 +243,69 @@ void reduction_op(
     size_t reduction_stride = plan.strides.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    if (plan.shape.size() == 0) {
-      for (int i = 0; i < out.size(); i += reduction_stride) {
-        int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        strided_reduce(
-            x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
-        out_ptr += reduction_stride;
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      reduction_stride,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      if (plan.shape.size() == 0) {
+        for (int i = 0; i < size; i += reduction_stride) {
+          int offset = elem_to_loc(i, shape, strides);
+          std::fill_n(out_ptr, reduction_stride, init);
+          strided_reduce(
+              in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
+          out_ptr += reduction_stride;
+        }
+      } else {
+        for (int i = 0; i < size; i += reduction_stride) {
+          int offset = elem_to_loc(i, shape, strides);
+          std::fill_n(out_ptr, reduction_stride, init);
+          nd_loop(
+              [&](int extra_offset) {
+                strided_reduce(
+                    in_ptr + offset + extra_offset,
+                    out_ptr,
+                    reduction_size,
+                    reduction_stride,
+                    Op{});
+              },
+              plan.shape,
+              plan.strides);
+          out_ptr += reduction_stride;
+        }
       }
-    } else {
-      for (int i = 0; i < out.size(); i += reduction_stride) {
-        int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        nd_loop(
-            [&](int extra_offset) {
-              strided_reduce(
-                  x_ptr + offset + extra_offset,
-                  out_ptr,
-                  reduction_size,
-                  reduction_stride,
-                  op);
-            },
-            plan.shape,
-            plan.strides);
-        out_ptr += reduction_stride;
-      }
-    }
+    });
     return;
   }
 
   if (plan.type == GeneralReduce) {
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    for (int i = 0; i < out.size(); i++, out_ptr++) {
-      int offset = elem_to_loc(i, shape, strides);
-      U val = init;
-      nd_loop(
-          [&](int extra_offset) {
-            val = op(val, *(x_ptr + offset + extra_offset));
-          },
-          plan.shape,
-          plan.strides);
-      *out_ptr = val;
-    }
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      for (int i = 0; i < size; i++, out_ptr++) {
+        int offset = elem_to_loc(i, shape, strides);
+        U val = init;
+        nd_loop(
+            [&](int extra_offset) {
+              val = Op{}(val, *(in_ptr + offset + extra_offset));
+            },
+            plan.shape,
+            plan.strides);
+        *out_ptr = val;
+      }
+    });
   }
 }
 
@@ -394,11 +434,12 @@ void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::And) {
-    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
+    reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
   } else {
-    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+    reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
   }
 }
 
@@ -407,18 +448,19 @@ void reduce_dispatch_sum_prod(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::Sum) {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
+      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
     } else {
-      reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
+      reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
     }
   } else {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
+      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
     } else {
-      reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
+      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
     }
   }
 }
 
@@ -428,13 +470,14 @@ void reduce_dispatch_min_max(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::Max) {
     auto init = Limits<InT>::min;
-    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+    reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
   } else {
     auto init = Limits<InT>::max;
-    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
+    reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
   }
 }
 
@@ -448,24 +491,28 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
        case bool_:
        case uint8:
        case int8:
-         reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_and_or<int8_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int16:
        case uint16:
        case float16:
        case bfloat16:
-         reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_and_or<int16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case uint32:
        case int32:
        case float32:
-         reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_and_or<int32_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case uint64:
        case int64:
        case float64:
        case complex64:
-         reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_and_or<int64_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
      }
      break;
@@ -476,34 +523,43 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
        case bool_:
        case uint8:
        case int8:
-         reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<int8_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int16:
        case uint16:
-         reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<int16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int32:
        case uint32:
-         reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<int32_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int64:
        case uint64:
-         reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<int64_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case float16:
-         reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<float16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case bfloat16:
-         reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<bfloat16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case float32:
-         reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<float>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case float64:
-         reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<double>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case complex64:
-         reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_sum_prod<complex64_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
      }
      break;
@@ -512,46 +568,59 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
    case Reduce::Min: {
      switch (in.dtype()) {
        case bool_:
-         reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
          break;
        case uint8:
-         reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<uint8_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case uint16:
-         reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<uint16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case uint32:
-         reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<uint32_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case uint64:
-         reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<uint64_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int8:
-         reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<int8_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int16:
-         reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<int16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int32:
-         reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<int32_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case int64:
-         reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<int64_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case float16:
-         reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<float16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case float32:
-         reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<float>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case float64:
-         reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<double>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case bfloat16:
-         reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<bfloat16_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
        case complex64:
-         reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+         reduce_dispatch_min_max<complex64_t>(
+             in, out, reduce_type_, axes_, stream());
          break;
      }
      break;
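Note the shift from passing reduction operators as values (`AndReduce()`, `SumReduce()`, ...) to passing them as template parameters and default-constructing `Op{}` inside the dispatched lambda: the closures stay trivially copyable and no operator object has to be kept alive across the deferred call. A call site now looks like, for example:

    // Before: reduction_op<int8_t, bool>(in, out, axes, true, AndReduce());
    // After:  the operator type is part of the instantiation and the stream
    //         is threaded through instead of an operator instance.
    reduction_op<int8_t, bool, AndReduce>(in, out, axes, true, stream);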
diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp
index 28c929beb..93f438718 100644
--- a/mlx/backend/cpu/scan.cpp
+++ b/mlx/backend/cpu/scan.cpp
@@ -4,6 +4,7 @@
 
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/primitives.h"
 
@@ -153,37 +154,44 @@ void strided_scan(
 
 template <typename T, typename U, typename Op>
 void scan_op(
-    const array& input,
-    array& output,
+    const array& in,
+    array& out,
     int axis,
     bool reverse,
     bool inclusive,
     const Op& op,
-    U init) {
-  output.set_data(allocator::malloc_or_wait(output.nbytes()));
+    U init,
+    Stream stream) {
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
 
-  if (input.flags().row_contiguous) {
-    if (input.strides()[axis] == 1) {
-      contiguous_scan(
-          input.data<T>(),
-          output.data<U>(),
-          input.size() / input.shape(axis),
-          input.shape(axis),
-          reverse,
-          inclusive,
-          op,
-          init);
+  if (in.flags().row_contiguous) {
+    if (in.strides()[axis] == 1) {
+      encoder.dispatch([in_ptr = in.data<T>(),
+                        out_ptr = out.data<U>(),
+                        count = in.size() / in.shape(axis),
+                        stride = in.shape(axis),
+                        reverse,
+                        inclusive,
+                        op = std::move(op),
+                        init]() {
+        contiguous_scan(
+            in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
+      });
     } else {
-      strided_scan(
-          input.data<T>(),
-          output.data<U>(),
-          input.size() / input.shape(axis) / input.strides()[axis],
-          input.shape(axis),
-          input.strides()[axis],
-          reverse,
-          inclusive,
-          op,
-          init);
+      encoder.dispatch([in_ptr = in.data<T>(),
+                        out_ptr = out.data<U>(),
+                        count = in.size() / in.shape(axis) / in.strides()[axis],
+                        size = in.shape(axis),
+                        stride = in.strides()[axis],
+                        reverse,
+                        inclusive,
+                        op = std::move(op),
+                        init]() {
+        strided_scan(
+            in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
+      });
     }
   } else {
     throw std::runtime_error("Scan op supports only contiguous inputs");
@@ -193,38 +201,39 @@ void scan_op(
 
 template <typename T, typename U>
 void scan_dispatch(
     Scan::ReduceType rtype,
-    const array& input,
-    array& output,
+    const array& in,
+    array& out,
     int axis,
     bool reverse,
-    bool inclusive) {
+    bool inclusive,
+    Stream stream) {
   switch (rtype) {
     case Scan::Sum: {
       auto op = [](U y, T x) { return y + x; };
       auto init = static_cast<U>(0);
-      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
       break;
     }
     case Scan::Prod: {
       auto op = [](U y, T x) { return y * x; };
       auto init = static_cast<U>(1);
-      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
       break;
     }
     case Scan::Min: {
       auto op = [](U y, T x) { return x < y ? x : y; };
-      auto init = (issubdtype(input.dtype(), floating))
+      auto init = (issubdtype(in.dtype(), floating))
           ? static_cast<U>(std::numeric_limits<float>::infinity())
           : std::numeric_limits<U>::max();
-      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
       break;
     }
     case Scan::Max: {
       auto op = [](U y, T x) { return x < y ? y : x; };
-      auto init = (issubdtype(input.dtype(), floating))
+      auto init = (issubdtype(in.dtype(), floating))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::min();
-      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
       break;
    }
  }
}
@@ -237,11 +246,14 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Ensure contiguity
   auto in = inputs[0];
+  bool copied = false;
   if (!in.flags().row_contiguous) {
     array arr_copy(in.shape(), in.dtype(), nullptr, {});
-    copy(in, arr_copy, CopyType::General);
+    copy(in, arr_copy, CopyType::General, stream());
     in = arr_copy;
+    copied = true;
   }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   switch (in.dtype()) {
     case bool_: {
@@ -252,65 +264,68 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
       // floats perhaps we should add the full all-to-all dispatch.
       if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
         scan_dispatch<bool, int32_t>(
-            reduce_type_, in, out, axis_, reverse_, inclusive_);
+            reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       } else {
         scan_dispatch<bool, bool>(
-            reduce_type_, in, out, axis_, reverse_, inclusive_);
+            reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       }
       break;
     }
     case uint8:
       scan_dispatch<uint8_t, uint8_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case uint16:
       scan_dispatch<uint16_t, uint16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case uint32:
       scan_dispatch<uint32_t, uint32_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case uint64:
       scan_dispatch<uint64_t, uint64_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case int8:
       scan_dispatch<int8_t, int8_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case int16:
       scan_dispatch<int16_t, int16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case int32:
       scan_dispatch<int32_t, int32_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case int64:
       scan_dispatch<int64_t, int64_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case float16:
       scan_dispatch<float16_t, float16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case float32:
       scan_dispatch<float, float>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case float64:
       scan_dispatch<double, double>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case bfloat16:
       scan_dispatch<bfloat16_t, bfloat16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_);
+          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
       break;
     case complex64:
       throw std::runtime_error("Scan ops do not support complex types yet");
       break;
   }
+  if (copied) {
+    cpu::get_command_encoder(stream()).add_temporary(std::move(in));
+  }
 }
 
 } // namespace mlx::core
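A recurring consequence of deferring work: any contiguous copy made on the spot must now outlive `eval_cpu`, which is what `add_temporary`/`add_temporaries` arrange. A sketch of the idiom (it mirrors `Scan::eval_cpu` above; the surrounding function is assumed to be a primitive's `eval_cpu`):

    array tmp(in.shape(), in.dtype(), nullptr, {});
    copy(in, tmp, CopyType::General, stream());  // itself dispatched on the stream
    // ... dispatch kernels that read tmp's buffer ...
    // Keep tmp's buffer alive until the stream has executed those tasks.
    cpu::get_command_encoder(stream()).add_temporary(std::move(tmp));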
"mlx/primitives.h" #include "mlx/types/limits.h" @@ -15,92 +16,100 @@ namespace { using namespace mlx::core::simd; template -void softmax(const array& in, array& out) { - constexpr bool same_t = std::is_same_v; - constexpr int N = std::min(max_size, max_size); +void softmax(const array& in, array& out, Stream stream) { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(in); + encoder.set_output_array(out); const T* in_ptr = in.data(); T* out_ptr = out.data(); + int M = in.shape().back(); int L = in.data_size() / M; - const T* current_in_ptr; - T* current_out_ptr; - for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { - // Find the maximum - current_in_ptr = in_ptr; - Simd vmaximum(-numeric_limits::infinity()); - size_t s = M; - while (s >= N) { - Simd vals = load(current_in_ptr); - vmaximum = maximum(vals, vmaximum); - current_in_ptr += N; - s -= N; - } + encoder.dispatch([in_ptr, out_ptr, M, L]() mutable { + constexpr bool same_t = std::is_same_v; + constexpr int N = std::min(max_size, max_size); - AccT maximum = max(vmaximum); - while (s-- > 0) { - maximum = std::max(maximum, static_cast(*current_in_ptr)); - current_in_ptr++; - } + const T* current_in_ptr; + T* current_out_ptr; - // Compute the normalizer and the exponentials - Simd vnormalizer(0.0); - current_out_ptr = out_ptr; - current_in_ptr = in_ptr; - s = M; - while (s >= N) { - Simd vexp = load(current_in_ptr); - vexp = exp(vexp - maximum); - if constexpr (same_t) { - store(current_out_ptr, vexp); - } - vnormalizer = vnormalizer + vexp; - current_in_ptr += N; - current_out_ptr += N; - s -= N; - } - AccT normalizer = sum(vnormalizer); - while (s-- > 0) { - AccT _exp = std::exp(*current_in_ptr - maximum); - if constexpr (same_t) { - *current_out_ptr = _exp; - } - normalizer += _exp; - current_in_ptr++; - current_out_ptr++; - } - normalizer = 1 / normalizer; - - // Normalize - current_out_ptr = out_ptr; - current_in_ptr = in_ptr; - s = M; - while (s >= N) { - if constexpr (same_t) { - store( - current_out_ptr, - Simd(load(current_out_ptr) * normalizer)); - } else { - Simd vexp = load(current_in_ptr); - vexp = exp(vexp - maximum) * normalizer; - store(current_out_ptr, Simd(vexp)); + for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { + // Find the maximum + current_in_ptr = in_ptr; + Simd vmaximum(-numeric_limits::infinity()); + size_t s = M; + while (s >= N) { + Simd vals = load(current_in_ptr); + vmaximum = maximum(vals, vmaximum); current_in_ptr += N; + s -= N; } - current_out_ptr += N; - s -= N; - } - while (s-- > 0) { - if constexpr (same_t) { - *current_out_ptr *= normalizer; - } else { - AccT _exp = std::exp(*current_in_ptr - maximum); - *current_out_ptr = static_cast(_exp * normalizer); + + AccT maximum = max(vmaximum); + while (s-- > 0) { + maximum = std::max(maximum, static_cast(*current_in_ptr)); current_in_ptr++; } - current_out_ptr++; + + // Compute the normalizer and the exponentials + Simd vnormalizer(0.0); + current_out_ptr = out_ptr; + current_in_ptr = in_ptr; + s = M; + while (s >= N) { + Simd vexp = load(current_in_ptr); + vexp = exp(vexp - maximum); + if constexpr (same_t) { + store(current_out_ptr, vexp); + } + vnormalizer = vnormalizer + vexp; + current_in_ptr += N; + current_out_ptr += N; + s -= N; + } + AccT normalizer = sum(vnormalizer); + while (s-- > 0) { + AccT _exp = std::exp(*current_in_ptr - maximum); + if constexpr (same_t) { + *current_out_ptr = _exp; + } + normalizer += _exp; + current_in_ptr++; + current_out_ptr++; + } + normalizer = 1 / normalizer; + + // 
Normalize + current_out_ptr = out_ptr; + current_in_ptr = in_ptr; + s = M; + while (s >= N) { + if constexpr (same_t) { + store( + current_out_ptr, + Simd(load(current_out_ptr) * normalizer)); + } else { + Simd vexp = load(current_in_ptr); + vexp = exp(vexp - maximum) * normalizer; + store(current_out_ptr, Simd(vexp)); + current_in_ptr += N; + } + current_out_ptr += N; + s -= N; + } + while (s-- > 0) { + if constexpr (same_t) { + *current_out_ptr *= normalizer; + } else { + AccT _exp = std::exp(*current_in_ptr - maximum); + *current_out_ptr = static_cast(_exp * normalizer); + current_in_ptr++; + } + current_out_ptr++; + } } - } + }); } } // namespace @@ -109,30 +118,32 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); // Make sure that the last dimension is contiguous - auto check_input = [](array x) { + auto set_output = [s = stream(), &out](const array& x) { bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc_or_wait(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); - copy(x, x_copy, CopyType::General); + copy(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); return x_copy; } }; - array in = check_input(std::move(inputs[0])); - if (in.is_donatable()) { - out.copy_shared_buffer(in); - } else { - out.set_data( - allocator::malloc_or_wait(in.data_size() * in.itemsize()), - in.data_size(), - in.strides(), - in.flags()); - } + + auto in = set_output(inputs[0]); switch (in.dtype()) { case bool_: @@ -148,24 +159,24 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { "Softmax is defined only for floating point types"); break; case float32: - softmax(in, out); + softmax(in, out, stream()); break; case float16: if (precise_) { - softmax(in, out); + softmax(in, out, stream()); } else { - softmax(in, out); + softmax(in, out, stream()); } break; case bfloat16: if (precise_) { - softmax(in, out); + softmax(in, out, stream()); } else { - softmax(in, out); + softmax(in, out, stream()); } break; case float64: - softmax(in, out); + softmax(in, out, stream()); break; case complex64: throw std::invalid_argument( diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index a6db5dcef..f66e7362a 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -7,6 +7,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" @@ -103,11 +104,11 @@ struct StridedIterator { T* ptr_; }; -template -void sort(const array& in, array& out, int axis) { +template +void sort(const array& in, array& out, int axis, Stream stream) { // Copy input to output CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; - copy(in, out, ctype); + copy(in, out, ctype, stream); // Get axis, shape and stride info axis = axis < 0 ? 
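The `check_input` → `set_output` rename above is not cosmetic: the output's buffer (whether donated or freshly allocated) must exist before `dispatch` runs, because the lambda captures `out`'s data pointer at dispatch time. The donation decision itself is unchanged in spirit:

    // Run softmax in place when the input buffer can be taken over.
    if (x.is_donatable()) {
      out.copy_shared_buffer(x);
    } else {
      out.set_data(
          allocator::malloc_or_wait(x.data_size() * x.itemsize()),
          x.data_size(), x.strides(), x.flags());
    }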
diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp
index a6db5dcef..f66e7362a 100644
--- a/mlx/backend/cpu/sort.cpp
+++ b/mlx/backend/cpu/sort.cpp
@@ -7,6 +7,7 @@
 
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/primitives.h"
 
@@ -103,11 +104,11 @@ struct StridedIterator {
   T* ptr_;
 };
 
-template <typename T>
-void sort(const array& in, array& out, int axis) {
+template <typename T>
+void sort(const array& in, array& out, int axis, Stream stream) {
   // Copy input to output
   CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype);
+  copy(in, out, ctype, stream);
 
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
@@ -126,19 +127,27 @@ void sort(const array& in, array& out, int axis) {
 
   // Perform sorting in place
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
-  for (int i = 0; i < n_rows; i++) {
-    T* data_ptr = out.data<T>() + src_it.loc;
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_output_array(out);
+  encoder.dispatch([out_ptr = out.data<T>(),
+                    src_it = std::move(src_it),
+                    n_rows,
+                    axis_size,
+                    axis_stride]() mutable {
+    for (int i = 0; i < n_rows; i++) {
+      T* data_ptr = out_ptr + src_it.loc;
 
-    StridedIterator st(data_ptr, axis_stride, 0);
-    StridedIterator ed(data_ptr, axis_stride, axis_size);
+      StridedIterator st(data_ptr, axis_stride, 0);
+      StridedIterator ed(data_ptr, axis_stride, axis_size);
 
-    std::stable_sort(st, ed);
-    src_it.step();
-  }
+      std::stable_sort(st, ed);
+      src_it.step();
+    }
+  });
 }
 
 template <typename T, typename IdxT = uint32_t>
-void argsort(const array& in, array& out, int axis) {
+void argsort(const array& in, array& out, int axis, Stream stream) {
   // Allocate output
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
@@ -167,35 +176,48 @@ void argsort(const array& in, array& out, int axis) {
       in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
   ContiguousIterator out_it(
       out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
-  for (int i = 0; i < n_rows; i++) {
-    const T* data_ptr = in.data<T>() + in_it.loc;
-    IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
-    in_it.step();
-    out_it.step();
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in_ptr = in.data<T>(),
+                    out_ptr = out.data<IdxT>(),
+                    in_it = std::move(in_it),
+                    out_it = std::move(out_it),
+                    n_rows,
+                    axis_size,
+                    in_stride,
+                    out_stride]() mutable {
+    for (int i = 0; i < n_rows; i++) {
+      const T* data_ptr = in_ptr + in_it.loc;
+      IdxT* idx_ptr = out_ptr + out_it.loc;
 
-    StridedIterator st_(idx_ptr, out_stride, 0);
-    StridedIterator ed_(idx_ptr, out_stride, axis_size);
+      in_it.step();
+      out_it.step();
 
-    // Initialize with iota
-    std::iota(st_, ed_, IdxT(0));
+      StridedIterator st_(idx_ptr, out_stride, 0);
+      StridedIterator ed_(idx_ptr, out_stride, axis_size);
 
-    // Sort according to vals
-    StridedIterator st(idx_ptr, out_stride, 0);
-    StridedIterator ed(idx_ptr, out_stride, axis_size);
+      // Initialize with iota
+      std::iota(st_, ed_, IdxT(0));
 
-    std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-      auto v1 = data_ptr[a * in_stride];
-      auto v2 = data_ptr[b * in_stride];
-      return v1 < v2 || (v1 == v2 && a < b);
-    });
-  }
+      // Sort according to vals
+      StridedIterator st(idx_ptr, out_stride, 0);
+      StridedIterator ed(idx_ptr, out_stride, axis_size);
+
+      std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
+        auto v1 = data_ptr[a * in_stride];
+        auto v2 = data_ptr[b * in_stride];
+        return v1 < v2 || (v1 == v2 && a < b);
+      });
+    }
+  });
 }
 
-template <typename T>
-void partition(const array& in, array& out, int axis, int kth) {
+template <typename T>
+void partition(const array& in, array& out, int axis, int kth, Stream stream) {
   // Copy input to output
   CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype);
+  copy(in, out, ctype, stream);
 
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
@@ -216,20 +238,34 @@ void partition(const array& in, array& out, int axis, int kth) {
 
   // Perform partition in place
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
-  for (int i = 0; i < n_rows; i++) {
-    T* data_ptr = out.data<T>() + src_it.loc;
-    src_it.step();
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_output_array(out);
+  encoder.dispatch([out_ptr = out.data<T>(),
+                    src_it = std::move(src_it),
+                    n_rows,
+                    axis_size,
+                    axis_stride,
+                    kth]() mutable {
+    for (int i = 0; i < n_rows; i++) {
+      T* data_ptr = out_ptr + src_it.loc;
+      src_it.step();
 
-    StridedIterator st(data_ptr, axis_stride, 0);
-    StridedIterator md(data_ptr, axis_stride, kth);
-    StridedIterator ed(data_ptr, axis_stride, axis_size);
+      StridedIterator st(data_ptr, axis_stride, 0);
+      StridedIterator md(data_ptr, axis_stride, kth);
+      StridedIterator ed(data_ptr, axis_stride, axis_size);
 
-    std::nth_element(st, md, ed);
-  }
+      std::nth_element(st, md, ed);
+    }
+  });
 }
 
 template <typename T, typename IdxT = uint32_t>
-void argpartition(const array& in, array& out, int axis, int kth) {
+void argpartition(
+    const array& in,
+    array& out,
+    int axis,
+    int kth,
+    Stream stream) {
   // Allocate output
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
@@ -260,29 +296,43 @@ void argpartition(const array& in, array& out, int axis, int kth) {
       in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
   ContiguousIterator out_it(
       out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
-  for (int i = 0; i < n_rows; i++) {
-    const T* data_ptr = in.data<T>() + in_it.loc;
-    IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
-    in_it.step();
-    out_it.step();
-    StridedIterator st_(idx_ptr, out_stride, 0);
-    StridedIterator ed_(idx_ptr, out_stride, axis_size);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in_ptr = in.data<T>(),
+                    out_ptr = out.data<IdxT>(),
+                    in_it = std::move(in_it),
+                    out_it = std::move(out_it),
+                    n_rows,
+                    axis_size,
+                    in_stride,
+                    out_stride,
+                    kth]() mutable {
+    for (int i = 0; i < n_rows; i++) {
+      const T* data_ptr = in_ptr + in_it.loc;
+      IdxT* idx_ptr = out_ptr + out_it.loc;
+      in_it.step();
+      out_it.step();
 
-    // Initialize with iota
-    std::iota(st_, ed_, IdxT(0));
+      StridedIterator st_(idx_ptr, out_stride, 0);
+      StridedIterator ed_(idx_ptr, out_stride, axis_size);
 
-    // Sort according to vals
-    StridedIterator st(idx_ptr, out_stride, 0);
-    StridedIterator md(idx_ptr, out_stride, kth);
-    StridedIterator ed(idx_ptr, out_stride, axis_size);
+      // Initialize with iota
+      std::iota(st_, ed_, IdxT(0));
 
-    std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-      auto v1 = data_ptr[a * in_stride];
-      auto v2 = data_ptr[b * in_stride];
-      return v1 < v2 || (v1 == v2 && a < b);
-    });
-  }
+      // Sort according to vals
+      StridedIterator st(idx_ptr, out_stride, 0);
+      StridedIterator md(idx_ptr, out_stride, kth);
+      StridedIterator ed(idx_ptr, out_stride, axis_size);
+
+      std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
+        auto v1 = data_ptr[a * in_stride];
+        auto v2 = data_ptr[b * in_stride];
+        return v1 < v2 || (v1 == v2 && a < b);
+      });
+    }
+  });
 }
 
 } // namespace
 
@@ -293,33 +343,33 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
   switch (in.dtype()) {
     case bool_:
-      return argsort<bool>(in, out, axis_);
+      return argsort<bool>(in, out, axis_, stream());
     case uint8:
-      return argsort<uint8_t>(in, out, axis_);
+      return argsort<uint8_t>(in, out, axis_, stream());
     case uint16:
-      return argsort<uint16_t>(in, out, axis_);
+      return argsort<uint16_t>(in, out, axis_, stream());
     case uint32:
-      return argsort<uint32_t>(in, out, axis_);
+      return argsort<uint32_t>(in, out, axis_, stream());
     case uint64:
-      return argsort<uint64_t>(in, out, axis_);
+      return argsort<uint64_t>(in, out, axis_, stream());
     case int8:
-      return argsort<int8_t>(in, out, axis_);
+      return argsort<int8_t>(in, out, axis_, stream());
     case int16:
-      return argsort<int16_t>(in, out, axis_);
+      return argsort<int16_t>(in, out, axis_, stream());
     case int32:
-      return argsort<int32_t>(in, out, axis_);
+      return argsort<int32_t>(in, out, axis_, stream());
     case int64:
-      return argsort<int64_t>(in, out, axis_);
+      return argsort<int64_t>(in, out, axis_, stream());
     case float32:
-      return argsort<float>(in, out, axis_);
+      return argsort<float>(in, out, axis_, stream());
     case float64:
-      return argsort<double>(in, out, axis_);
+      return argsort<double>(in, out, axis_, stream());
     case float16:
-      return argsort<float16_t>(in, out, axis_);
+      return argsort<float16_t>(in, out, axis_, stream());
     case bfloat16:
-      return argsort<bfloat16_t>(in, out, axis_);
+      return argsort<bfloat16_t>(in, out, axis_, stream());
     case complex64:
-      return argsort<complex64_t>(in, out, axis_);
+      return argsort<complex64_t>(in, out, axis_, stream());
   }
 }
 
@@ -329,33 +379,33 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
   switch (in.dtype()) {
     case bool_:
-      return sort<bool>(in, out, axis_);
+      return sort<bool>(in, out, axis_, stream());
     case uint8:
-      return sort<uint8_t>(in, out, axis_);
+      return sort<uint8_t>(in, out, axis_, stream());
     case uint16:
-      return sort<uint16_t>(in, out, axis_);
+      return sort<uint16_t>(in, out, axis_, stream());
     case uint32:
-      return sort<uint32_t>(in, out, axis_);
+      return sort<uint32_t>(in, out, axis_, stream());
     case uint64:
-      return sort<uint64_t>(in, out, axis_);
+      return sort<uint64_t>(in, out, axis_, stream());
     case int8:
-      return sort<int8_t>(in, out, axis_);
+      return sort<int8_t>(in, out, axis_, stream());
     case int16:
-      return sort<int16_t>(in, out, axis_);
+      return sort<int16_t>(in, out, axis_, stream());
     case int32:
-      return sort<int32_t>(in, out, axis_);
+      return sort<int32_t>(in, out, axis_, stream());
     case int64:
-      return sort<int64_t>(in, out, axis_);
+      return sort<int64_t>(in, out, axis_, stream());
     case float32:
-      return sort<float>(in, out, axis_);
+      return sort<float>(in, out, axis_, stream());
     case float64:
-      return sort<double>(in, out, axis_);
+      return sort<double>(in, out, axis_, stream());
     case float16:
-      return sort<float16_t>(in, out, axis_);
+      return sort<float16_t>(in, out, axis_, stream());
     case bfloat16:
-      return sort<bfloat16_t>(in, out, axis_);
+      return sort<bfloat16_t>(in, out, axis_, stream());
     case complex64:
-      return sort<complex64_t>(in, out, axis_);
+      return sort<complex64_t>(in, out, axis_, stream());
   }
 }
 
@@ -365,33 +415,33 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
   switch (in.dtype()) {
     case bool_:
-      return argpartition<bool>(in, out, axis_, kth_);
+      return argpartition<bool>(in, out, axis_, kth_, stream());
     case uint8:
-      return argpartition<uint8_t>(in, out, axis_, kth_);
+      return argpartition<uint8_t>(in, out, axis_, kth_, stream());
     case uint16:
-      return argpartition<uint16_t>(in, out, axis_, kth_);
+      return argpartition<uint16_t>(in, out, axis_, kth_, stream());
     case uint32:
-      return argpartition<uint32_t>(in, out, axis_, kth_);
+      return argpartition<uint32_t>(in, out, axis_, kth_, stream());
     case uint64:
-      return argpartition<uint64_t>(in, out, axis_, kth_);
+      return argpartition<uint64_t>(in, out, axis_, kth_, stream());
     case int8:
-      return argpartition<int8_t>(in, out, axis_, kth_);
+      return argpartition<int8_t>(in, out, axis_, kth_, stream());
     case int16:
-      return argpartition<int16_t>(in, out, axis_, kth_);
+      return argpartition<int16_t>(in, out, axis_, kth_, stream());
     case int32:
-      return argpartition<int32_t>(in, out, axis_, kth_);
+      return argpartition<int32_t>(in, out, axis_, kth_, stream());
     case int64:
-      return argpartition<int64_t>(in, out, axis_, kth_);
+      return argpartition<int64_t>(in, out, axis_, kth_, stream());
     case float32:
-      return argpartition<float>(in, out, axis_, kth_);
+      return argpartition<float>(in, out, axis_, kth_, stream());
     case float64:
-      return argpartition<double>(in, out, axis_, kth_);
+      return argpartition<double>(in, out, axis_, kth_, stream());
     case float16:
-      return argpartition<float16_t>(in, out, axis_, kth_);
+      return argpartition<float16_t>(in, out, axis_, kth_, stream());
     case bfloat16:
-      return argpartition<bfloat16_t>(in, out, axis_, kth_);
+      return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
     case complex64:
-      return argpartition<complex64_t>(in, out, axis_, kth_);
+      return argpartition<complex64_t>(in, out, axis_, kth_, stream());
   }
 }
 
@@ -401,33 +451,33 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
   switch (in.dtype()) {
     case bool_:
-      return partition<bool>(in, out, axis_, kth_);
+      return partition<bool>(in, out, axis_, kth_, stream());
     case uint8:
-      return partition<uint8_t>(in, out, axis_, kth_);
+      return partition<uint8_t>(in, out, axis_, kth_, stream());
     case uint16:
-      return partition<uint16_t>(in, out, axis_, kth_);
+      return partition<uint16_t>(in, out, axis_, kth_, stream());
     case uint32:
-      return partition<uint32_t>(in, out, axis_, kth_);
+      return partition<uint32_t>(in, out, axis_, kth_, stream());
     case uint64:
-      return partition<uint64_t>(in, out, axis_, kth_);
+      return partition<uint64_t>(in, out, axis_, kth_, stream());
     case int8:
-      return partition<int8_t>(in, out, axis_, kth_);
+      return partition<int8_t>(in, out, axis_, kth_, stream());
     case int16:
-      return partition<int16_t>(in, out, axis_, kth_);
+      return partition<int16_t>(in, out, axis_, kth_, stream());
     case int32:
-      return partition<int32_t>(in, out, axis_, kth_);
+      return partition<int32_t>(in, out, axis_, kth_, stream());
     case int64:
-      return partition<int64_t>(in, out, axis_, kth_);
+      return partition<int64_t>(in, out, axis_, kth_, stream());
     case float32:
-      return partition<float>(in, out, axis_, kth_);
+      return partition<float>(in, out, axis_, kth_, stream());
     case float64:
-      return partition<double>(in, out, axis_, kth_);
+      return partition<double>(in, out, axis_, kth_, stream());
     case float16:
-      return partition<float16_t>(in, out, axis_, kth_);
+      return partition<float16_t>(in, out, axis_, kth_, stream());
     case bfloat16:
-      return partition<bfloat16_t>(in, out, axis_, kth_);
+      return partition<bfloat16_t>(in, out, axis_, kth_, stream());
     case complex64:
-      return partition<complex64_t>(in, out, axis_, kth_);
+      return partition<complex64_t>(in, out, axis_, kth_, stream());
   }
 }
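`ContiguousIterator` is stateful, so the sort kernels move it into the closure and mark the lambda `mutable` so the deferred loop can advance the captured copy. The shape and stride arithmetic stays on the calling thread; only the traversal runs later. Distilled (names match the `sort` kernel above):

    encoder.dispatch([out_ptr = out.data<T>(),
                      src_it = std::move(src_it),  // iterator captured by value
                      n_rows, axis_size, axis_stride]() mutable {
      for (int i = 0; i < n_rows; i++) {
        // ... sort the row starting at out_ptr + src_it.loc ...
        src_it.step();  // legal only because the lambda is mutable
      }
    });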
diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp
index 88b127bce..86745e545 100644
--- a/mlx/backend/cpu/svd.cpp
+++ b/mlx/backend/cpu/svd.cpp
@@ -2,13 +2,18 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
 
 template <typename T>
-void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
+void svd_impl(
+    const array& a,
+    std::vector<array>& outputs,
+    bool compute_uv,
+    Stream stream) {
   // Lapack uses the column-major convention. To avoid having to transpose
   // the input and then transpose the outputs, we swap the indices/sizes of the
   // matrices and take advantage of the following identity (see
@@ -22,118 +27,24 @@ void svd_impl(
   const int N = a.shape(-1);
   const int K = std::min(M, N);
 
-  // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
-  const int lda = N;
-  // U of shape M x M. (N x N in lapack).
-  const int ldu = N;
-  // Vᵀ of shape N x N. (M x M in lapack).
-  const int ldvt = M;
-
   size_t num_matrices = a.size() / (M * N);
 
   // lapack clobbers the input, so we have to make a copy.
   array in(a.shape(), a.dtype(), nullptr, {});
-  copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+  copy(
+      a,
+      in,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+      stream);
 
-  auto job_u = (u_data && vt_data) ? "V" : "N";
-  auto job_vt = (u_data && vt_data) ? "V" : "N";
-  static constexpr auto range = "A";
+  // Allocate outputs.
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  auto in_ptr = in.data<T>();
+  T* u_ptr;
+  T* s_ptr;
+  T* vt_ptr;
 
-  // Will contain the number of singular values after the call has returned.
-  int ns = 0;
-  T workspace_dimension = 0;
-
-  // Will contain the indices of eigenvectors that failed to converge (not used
-  // here but required by lapack).
-  auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
-
-  static const int lwork_query = -1;
-
-  static const int ignored_int = 0;
-  static const T ignored_float = 0;
-  static T ignored_output = 0;
-
-  int info;
-
-  // Compute workspace size.
-  gesvdx<T>(
-      /* jobu = */ job_u,
-      /* jobvt = */ job_vt,
-      /* range = */ range,
-      // M and N are swapped since lapack expects column-major.
-      /* m = */ &N,
-      /* n = */ &M,
-      /* a = */ nullptr,
-      /* lda = */ &lda,
-      /* vl = */ &ignored_float,
-      /* vu = */ &ignored_float,
-      /* il = */ &ignored_int,
-      /* iu = */ &ignored_int,
-      /* ns = */ &ns,
-      /* s = */ nullptr,
-      /* u = */ nullptr,
-      /* ldu = */ &ldu,
-      /* vt = */ nullptr,
-      /* ldvt = */ &ldvt,
-      /* work = */ &workspace_dimension,
-      /* lwork = */ &lwork_query,
-      /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
-      /* info = */ &info);
-
-  if (info != 0) {
-    std::stringstream ss;
-    ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
-    throw std::runtime_error(ss.str());
-  }
-
-  const int lwork = workspace_dimension;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
-
-  // Loop over matrices.
-  for (int i = 0; i < num_matrices; i++) {
-    gesvdx<T>(
-        /* jobu = */ job_u,
-        /* jobvt = */ job_vt,
-        /* range = */ range,
-        // M and N are swapped since lapack expects column-major.
-        /* m = */ &N,
-        /* n = */ &M,
-        /* a = */ in.data<T>() + M * N * i,
-        /* lda = */ &lda,
-        /* vl = */ &ignored_float,
-        /* vu = */ &ignored_float,
-        /* il = */ &ignored_int,
-        /* iu = */ &ignored_int,
-        /* ns = */ &ns,
-        /* s = */ s_data + K * i,
-        // According to the identity above, lapack will write Vᵀᵀ as U.
-        /* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
-        /* ldu = */ &ldu,
-        // According to the identity above, lapack will write Uᵀ as Vᵀ.
-        /* vt = */ u_data ? u_data + M * M * i : &ignored_output,
-        /* ldvt = */ &ldvt,
-        /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
-        /* lwork = */ &lwork,
-        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
-        /* info = */ &info);
-
-    if (info != 0) {
-      std::stringstream ss;
-      ss << "[SVD::eval_cpu] failed with code " << info;
-      throw std::runtime_error(ss.str());
-    }
-
-    if (ns != K) {
-      std::stringstream ss;
-      ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
-         << " were computed.";
-      throw std::runtime_error(ss.str());
-    }
-  }
-}
-
-template <typename T>
-void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
   if (compute_uv) {
     array& u = outputs[0];
     array& s = outputs[1];
@@ -143,25 +54,147 @@ void svd_impl(
     s.set_data(allocator::malloc_or_wait(s.nbytes()));
     vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
 
-    svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
+    encoder.set_output_array(u);
+    encoder.set_output_array(s);
+    encoder.set_output_array(vt);
+
+    s_ptr = s.data<T>();
+    u_ptr = u.data<T>();
+    vt_ptr = vt.data<T>();
   } else {
     array& s = outputs[0];
     s.set_data(allocator::malloc_or_wait(s.nbytes()));
 
-    svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
+    encoder.set_output_array(s);
+
+    s_ptr = s.data<T>();
+    u_ptr = nullptr;
+    vt_ptr = nullptr;
   }
+
+  encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
+    // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
+    const int lda = N;
+    // U of shape M x M. (N x N in lapack).
+    const int ldu = N;
+    // Vᵀ of shape N x N. (M x M in lapack).
+    const int ldvt = M;
+
+    auto job_u = (u_ptr) ? "V" : "N";
+    auto job_vt = (u_ptr) ? "V" : "N";
+    static constexpr auto range = "A";
+
+    // Will contain the number of singular values after the call has returned.
+    int ns = 0;
+    T workspace_dimension = 0;
+
+    // Will contain the indices of eigenvectors that failed to converge (not
+    // used here but required by lapack).
+    auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
+
+    static const int lwork_query = -1;
+
+    static const int ignored_int = 0;
+    static const T ignored_float = 0;
+
+    int info;
+
+    // Compute workspace size.
+    gesvdx<T>(
+        /* jobu = */ job_u,
+        /* jobvt = */ job_vt,
+        /* range = */ range,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ nullptr,
+        /* lda = */ &lda,
+        /* vl = */ &ignored_float,
+        /* vu = */ &ignored_float,
+        /* il = */ &ignored_int,
+        /* iu = */ &ignored_int,
+        /* ns = */ &ns,
+        /* s = */ nullptr,
+        /* u = */ nullptr,
+        /* ldu = */ &ldu,
+        /* vt = */ nullptr,
+        /* ldvt = */ &ldvt,
+        /* work = */ &workspace_dimension,
+        /* lwork = */ &lwork_query,
+        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    const int lwork = workspace_dimension;
+    auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+
+    // Loop over matrices.
+    for (int i = 0; i < num_matrices; i++) {
+      gesvdx<T>(
+          /* jobu = */ job_u,
+          /* jobvt = */ job_vt,
+          /* range = */ range,
+          // M and N are swapped since lapack expects column-major.
+          /* m = */ &N,
+          /* n = */ &M,
+          /* a = */ in_ptr + M * N * i,
+          /* lda = */ &lda,
+          /* vl = */ &ignored_float,
+          /* vu = */ &ignored_float,
+          /* il = */ &ignored_int,
+          /* iu = */ &ignored_int,
+          /* ns = */ &ns,
+          /* s = */ s_ptr + K * i,
+          // According to the identity above, lapack will write Vᵀᵀ as U.
+          /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
+          /* ldu = */ &ldu,
+          // According to the identity above, lapack will write Uᵀ as Vᵀ.
+          /* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
+          /* ldvt = */ &ldvt,
+          /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
+          /* lwork = */ &lwork,
+          /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
+          /* info = */ &info);
+
+      if (info != 0) {
+        std::stringstream ss;
+        ss << "svd_impl: sgesvdx_ failed with code " << info;
+        throw std::runtime_error(ss.str());
+      }
+
+      if (ns != K) {
+        std::stringstream ss;
+        ss << "svd_impl: expected " << K << " singular values, but " << ns
+           << " were computed.";
+        throw std::runtime_error(ss.str());
+      }
+    }
+  });
+  encoder.add_temporary(in);
 }
 
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   switch (inputs[0].dtype()) {
     case float32:
-      compute_svd<float>(inputs[0], compute_uv_, outputs);
+      svd_impl<float>(inputs[0], outputs, compute_uv_, stream());
       break;
     case float64:
-      compute_svd<double>(inputs[0], compute_uv_, outputs);
+      svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
       break;
     default:
       throw std::runtime_error(
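Two lifetime details in the SVD port are easy to miss. First, because the LAPACK workspace query and the per-matrix `gesvdx` loop now run inside the dispatched lambda, failures surface when the stream executes the task rather than inside `eval_cpu` itself. Second, the clobbered input copy must outlive the deferred call, which is the role of the trailing `add_temporary`:

    copy(a, in, ..., stream);           // deferred copy into `in`
    encoder.dispatch([in_ptr, ...]() {  // reads in_ptr later, on the stream
      // ... gesvdx calls ...
    });
    encoder.add_temporary(in);          // keep `in`'s buffer alive until then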
diff --git a/mlx/backend/cpu/ternary.h b/mlx/backend/cpu/ternary.h
index 87c27c86c..7b89c7d74 100644
--- a/mlx/backend/cpu/ternary.h
+++ b/mlx/backend/cpu/ternary.h
@@ -5,6 +5,8 @@
 #include "mlx/array.h"
 #include "mlx/backend/common/ternary.h"
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/encoder.h"
+#include "mlx/primitives.h"
 
 namespace mlx::core {
 
@@ -53,22 +55,18 @@ void ternary_op_dims(
 
 template <typename T1, typename T2, typename T3, typename U, typename Op>
 void ternary_op_dispatch_dims(
-    const array& a,
-    const array& b,
-    const array& c,
-    array& out,
-    Op op) {
-  auto [shape, strides] = collapse_contiguous_dims(
-      a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
+    const T1* a_ptr,
+    const T2* b_ptr,
+    const T3* c_ptr,
+    U* out_ptr,
+    Op op,
+    size_t size,
+    Shape& shape,
+    std::vector<Strides>& strides) {
   const auto& a_strides = strides[0];
   const auto& b_strides = strides[1];
   const auto& c_strides = strides[2];
   const auto& out_strides = strides[3];
-
-  const T1* a_ptr = a.data<T1>();
-  const T2* b_ptr = b.data<T2>();
-  const T3* c_ptr = c.data<T3>();
-  U* out_ptr = out.data<U>();
   int ndim = shape.size();
   switch (ndim) {
     case 1:
@@ -105,7 +103,7 @@ void ternary_op_dispatch_dims(
     ContiguousIterator b_it(shape, b_strides, ndim - 2);
     ContiguousIterator c_it(shape, c_strides, ndim - 2);
     auto stride = out_strides[ndim - 3];
-    for (size_t elem = 0; elem < a.size(); elem += stride) {
+    for (size_t elem = 0; elem < size; elem += stride) {
       ternary_op_dims(
           a_ptr + a_it.loc,
           b_ptr + b_it.loc,
@@ -134,23 +132,53 @@ void ternary_op(
   TernaryOpType topt = get_ternary_op_type(a, b, c);
   set_ternary_op_output_data(a, b, c, out, topt);
 
-  // The full computation is scalar-scalar-scalar so we call the base op once.
+  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+  U* out_ptr = out.data<U>();
+
   if (topt == TernaryOpType::ScalarScalarScalar) {
-    *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
+    encoder.dispatch(
+        [a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
+          *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
+        });
   } else if (topt == TernaryOpType::VectorVectorVector) {
-    const T1* a_ptr = a.data<T1>();
-    const T2* b_ptr = b.data<T2>();
-    const T3* c_ptr = c.data<T3>();
-    U* out_ptr = out.data<U>();
-    for (size_t i = 0; i < out.size(); ++i) {
-      *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
-      a_ptr++;
-      b_ptr++;
-      c_ptr++;
-      out_ptr++;
-    }
+    encoder.dispatch([a_ptr,
+                      b_ptr,
+                      c_ptr,
+                      out_ptr,
+                      op = std::move(op),
+                      size = out.size()]() mutable {
+      for (size_t i = 0; i < size; ++i) {
+        *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
+        a_ptr++;
+        b_ptr++;
+        c_ptr++;
+        out_ptr++;
+      }
+    });
   } else {
-    ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
+    auto [shape, strides] = collapse_contiguous_dims(
+        a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
+    encoder.dispatch(
+        [a_ptr,
+         b_ptr,
+         c_ptr,
+         out_ptr,
+         op = std::move(op),
+         size = out.size(),
+         shape = std::move(shape),
+         strides = std::move(strides)]() mutable {
+          ternary_op_dispatch_dims(
+              a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides);
+        });
   }
 }
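Because these header-level helpers are called from many `eval_cpu` methods without a stream argument, they recover the stream from the output array's primitive instead (hence the new `mlx/primitives.h` include):

    auto& encoder = cpu::get_command_encoder(out.primitive().stream());

which is equivalent to threading `stream()` down from the primitive, as the `.cpp` files above do.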
a.strides()]() mutable { + auto ndim = shapes.size(); + if (contig) { + constexpr int N = simd::max_size; + while (data_size >= N) { + simd::store(dst, Op{}(simd::load(src))); + data_size -= N; + src += N; + dst += N; + } + while (data_size > 0) { + *dst = Op{}(*src); + data_size--; + dst++; + src++; + } + } else { + size_t shape = ndim > 0 ? shapes.back() : 1; + size_t stride = ndim > 0 ? strides.back() : 1; + if (ndim <= 1) { + unary_op(src, dst, shape, stride); + return; + } + auto it = ContiguousIterator(shapes, strides, ndim - 1); + for (size_t elem = 0; elem < size; elem += shape) { + unary_op(src + it.loc, dst + elem, shape, stride); + it.step(); + } } - while (size > 0) { - *dst = op(*a_ptr); - size--; - dst++; - a_ptr++; - } - } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - U* dst = out.data(); - size_t shape = a.ndim() > 0 ? a.shape(-1) : 1; - size_t stride = a.ndim() > 0 ? a.strides(-1) : 1; - if (a.ndim() <= 1) { - unary_op(a_ptr, dst, op, shape, stride); - return; - } - ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1); - for (size_t elem = 0; elem < a.size(); elem += shape) { - unary_op(a_ptr + it.loc, dst + elem, op, shape, stride); - it.step(); - } - } + }); } template diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index c7ebc9635..f80f8c3e4 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -102,16 +102,9 @@ void binary_op_gpu_inplace( auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - // - If a is donated it goes to the first output - // - If b is donated it goes to the first output if a was not donated - // otherwise it goes to the second output. - // - If there is only one output only one of a and b will be donated. - bool donate_a = a.data_shared_ptr() == nullptr; - bool donate_b = b.data_shared_ptr() == nullptr; int arg_idx = 0; - compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++); - compute_encoder.set_input_array( - donate_b ? (donate_a ? 
outputs[1] : outputs[0]) : b, arg_idx++); + compute_encoder.set_input_array(a, arg_idx++); + compute_encoder.set_input_array(b, arg_idx++); compute_encoder.set_output_array(outputs[0], arg_idx++); if (outputs.size() == 2) { compute_encoder.set_output_array(outputs[1], arg_idx++); @@ -164,8 +157,8 @@ void binary_op_gpu( auto& a = inputs[0]; auto& b = inputs[1]; auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, outputs[0], bopt, true); - set_binary_op_output_data(a, b, outputs[1], bopt, true); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); binary_op_gpu_inplace(inputs, outputs, op, s); } @@ -195,7 +188,7 @@ void binary_op_gpu( auto& a = inputs[0]; auto& b = inputs[1]; auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt, true); + set_binary_op_output_data(a, b, out, bopt); binary_op_gpu_inplace(inputs, out, op, s); } diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index e8e5e77e8..154273233 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -457,7 +457,7 @@ void Compiled::eval_gpu( } compiled_allocate_outputs( - inputs, outputs, inputs_, constant_ids_, contiguous, true); + inputs, outputs, inputs_, constant_ids_, contiguous); // Put the outputs in for (auto& x : outputs) { diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 88d63ba72..1750a6908 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -14,25 +14,11 @@ namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - if (ctype == CopyType::Vector) { - // If the input is donateable, we are doing a vector copy and the types - // have the same size, then the input buffer can hold the output. - if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.move_shared_buffer(in); - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. - if (in.dtype() == out.dtype()) { - return; - } - } else { - out.set_data( - allocator::malloc_or_wait(in.data_size() * out.itemsize()), - in.data_size(), - in.strides(), - in.flags()); - } - } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. 
+ return; } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 06681c458..e28989a5c 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -168,9 +168,10 @@ void CommandEncoder::set_output_array( register_output_array(a); } -void CommandEncoder::register_output_array(array& a) { +void CommandEncoder::register_output_array(const array& a) { all_outputs_.insert(a.buffer().ptr()); - auto buf = static_cast(a.buffer().ptr()); + + auto buf = static_cast(const_cast(a.buffer().ptr())); if (concurrent_) { concurrent_outputs_.insert(buf); } else { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 36d8747fd..1fe7cf76f 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -62,7 +62,7 @@ struct CommandEncoder { void set_input_array(const array& a, int idx, int64_t offset = 0); void set_output_array(array& a, int idx, int64_t offset = 0); - void register_output_array(array& a); + void register_output_array(const array& a); void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims); void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims); void maybeInsertBarrier(); diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 9ca727ef4..82e8fff7d 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -4,149 +4,30 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/event.h" -#include "mlx/backend/metal/fence.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" #include "mlx/distributed/primitives.h" +#include "mlx/fence.h" #include "mlx/scheduler.h" namespace mlx::core::distributed { -void signal_and_wait(const Event& e_signal, const Event& e_wait) { - if (e_signal.valid()) { - encode_signal(e_signal); - } - encode_wait(e_wait); +void AllReduce::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation."); } -void AllReduce::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - assert(inputs.size() == 1); - assert(outputs.size() == 1); - - auto& in = inputs[0]; - - Fence f{stream()}; - - if (in.event().valid()) { - f.update_gpu(in); - } - - auto& out = outputs[0]; - if (in.is_donatable()) { - out.move_shared_buffer(in); - } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - } - f.wait_gpu(out); - - auto task = [in = in, - out = unsafe_weak_copy(out), - f = std::move(f), - reduce_type = reduce_type_, - group = group()]() mutable { - if (in.event().valid()) { - f.wait(); - } - switch (reduce_type) { - case Sum: - distributed::detail::all_sum( - group, in.data_shared_ptr() == nullptr ? 
out : in, out); - break; - default: - throw std::runtime_error("Only all reduce sum is supported for now"); - } - f.update(); - }; - scheduler::enqueue(detail::communication_stream(), std::move(task)); +void AllGather::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation."); } -void AllGather::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - assert(inputs.size() == 1); - assert(outputs.size() == 1); - auto& in = inputs[0]; - auto& out = outputs[0]; - - out.set_data(allocator::malloc_or_wait(out.nbytes())); - - Fence f{stream()}; - - if (in.event().valid()) { - f.update_gpu(in); - } - f.wait_gpu(out); - - auto task = [in = in, - out = unsafe_weak_copy(out), - f = std::move(f), - group = group()]() mutable { - if (in.event().valid()) { - f.wait(); - } - distributed::detail::all_gather(group, in, out); - f.update(); - }; - scheduler::enqueue(detail::communication_stream(), std::move(task)); +void Send::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[Send::eval_gpu] has no GPU implementation."); } -void Send::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - assert(inputs.size() == 1); - assert(outputs.size() == 1); - - auto& in = inputs[0]; - - // Encode a signal event for the input - Fence f{stream()}; - if (in.event().valid()) { - f.update_gpu(in); - } - - auto& out = outputs[0]; - move_or_copy(in, out); - - // Schedule an async send on the comm stream - auto task = [in = in, - out = unsafe_weak_copy(out), - f = std::move(f), - group = group(), - dst = dst_]() mutable { - if (in.event().valid()) { - f.wait(); - } - distributed::detail::send(group, out, dst); - }; - scheduler::enqueue(detail::communication_stream(), std::move(task)); -} - -void Recv::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - assert(inputs.size() == 0); - assert(outputs.size() == 1); - - auto& out = outputs[0]; - - out.set_data(allocator::malloc_or_wait(out.nbytes())); - - Fence f{stream()}; - f.wait_gpu(out); - - // Schedule an async recv on the comm stream - auto task = [out = unsafe_weak_copy(out), - f = std::move(f), - group = group(), - src = src_]() mutable { - distributed::detail::recv(group, out, src); - f.update(); - }; - scheduler::enqueue(detail::communication_stream(), std::move(task)); +void Recv::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation."); } } // namespace mlx::core::distributed diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index e4b78985e..58d1521c9 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -3,52 +3,58 @@ #include "mlx/event.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal_impl.h" +#include "mlx/scheduler.h" namespace mlx::core { -void encode_wait(Event e) { - auto& d = metal::device(e.stream().device); - d.end_encoding(e.stream().index); - auto command_buffer = d.get_command_buffer(e.stream().index); - command_buffer->encodeWait( - static_cast(e.raw_event().get()), e.value()); - command_buffer->addCompletedHandler( - [e = std::move(e)](MTL::CommandBuffer* cbuf) {}); -} - -void encode_signal(Event e) { - auto& d = metal::device(e.stream().device); - d.end_encoding(e.stream().index); - auto command_buffer = d.get_command_buffer(e.stream().index); - command_buffer->encodeSignalEvent( - static_cast(e.raw_event().get()), e.value()); - command_buffer->addCompletedHandler( - [e = 
std::move(e)](MTL::CommandBuffer* cbuf) {});
-}
-
-Event::Event(const Stream& stream) : stream_(stream) {
+Event::Event(Stream stream) : stream_(stream) {
   auto dtor = [](void* ptr) {
     auto p = metal::new_scoped_memory_pool();
     static_cast<MTL::SharedEvent*>(ptr)->release();
   };
   auto p = metal::new_scoped_memory_pool();
   event_ = std::shared_ptr<void>(
-      metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
+      metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
 }

 void Event::wait() {
-  if (!static_cast<MTL::SharedEvent*>(raw_event().get())
+  if (!static_cast<MTL::SharedEvent*>(event_.get())
           ->waitUntilSignaledValue(value(), -1)) {
     throw std::runtime_error("[Event::wait] Timed out");
   }
 }

 void Event::signal() {
-  static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
+  static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
+}
+
+void Event::wait(Stream stream) {
+  if (stream.device == Device::cpu) {
+    scheduler::enqueue(stream, [*this]() mutable { wait(); });
+  } else {
+    auto& d = metal::device(stream.device);
+    d.end_encoding(stream.index);
+    auto command_buffer = d.get_command_buffer(stream.index);
+    command_buffer->encodeWait(
+        static_cast<MTL::SharedEvent*>(event_.get()), value());
+    command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
+  }
+}
+
+void Event::signal(Stream stream) {
+  if (stream.device == Device::cpu) {
+    scheduler::enqueue(stream, [*this]() mutable { signal(); });
+  } else {
+    auto& d = metal::device(stream.device);
+    d.end_encoding(stream.index);
+    auto command_buffer = d.get_command_buffer(stream.index);
+    command_buffer->encodeSignalEvent(
+        static_cast<MTL::SharedEvent*>(event_.get()), value());
+    command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
+  }
 }

 bool Event::is_signaled() const {
-  return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
+  return static_cast<MTL::SharedEvent*>(event_.get())->signaledValue() >=
       value();
 }
diff --git a/mlx/backend/metal/event.h b/mlx/backend/metal/event.h
deleted file mode 100644
index b7f29e73d..000000000
--- a/mlx/backend/metal/event.h
+++ /dev/null
@@ -1,10 +0,0 @@
-// Copyright © 2024 Apple Inc.
-#pragma once
-
-namespace mlx::core {
-
-void encode_wait(Event e);
-
-void encode_signal(Event e);
-
-} // namespace mlx::core
diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp
index 54ce8ea8d..f5502231a 100644
--- a/mlx/backend/metal/fence.cpp
+++ b/mlx/backend/metal/fence.cpp
@@ -1,49 +1,81 @@
 // Copyright © 2024 Apple Inc.
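// The rewrite below moves Fence behind mlx/fence.h and hides the Metal
// details in a FenceImpl: by default it wraps an MTL::SharedEvent, and with
// MLX_METAL_FAST_SYNCH=1 on Metal 3 devices it polls a shared uint32 buffer
// instead. A sketch of the intended pairing under the new stream-based API
// (the producer/consumer stream names are illustrative):
//
//   Fence f{stream};
//   f.update(producer_stream, x); // signal that x is ready
//   f.wait(consumer_stream, x);   // block consumer work until it is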
- -#include "mlx/backend/metal/fence.h" +#include "mlx/fence.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/event.h" #include "mlx/backend/metal/metal_impl.h" +#include "mlx/scheduler.h" #include "mlx/utils.h" namespace mlx::core { -Fence::Fence(const Stream& stream) : stream_(stream) { - auto d = metal::device(stream.device).mtl_device(); - if (!d->supportsFamily(MTL::GPUFamilyMetal3)) { - use_fast_ = false; - } else if (__builtin_available(macOS 15, iOS 18, *)) { - use_fast_ = env::metal_fast_synch(); - } - if (!use_fast_) { - // Wraps Metal SharedEvent - auto dtor = [](void* ptr) { +struct FenceImpl { + FenceImpl() { + auto d = metal::device(Device::gpu).mtl_device(); + if (!d->supportsFamily(MTL::GPUFamilyMetal3)) { + use_fast = false; + } else if (__builtin_available(macOS 15, iOS 18, *)) { + use_fast = env::metal_fast_synch(); + } + + if (!use_fast) { auto p = metal::new_scoped_memory_pool(); - static_cast(ptr)->release(); - }; - auto p = metal::new_scoped_memory_pool(); - fence_ = std::shared_ptr( - metal::device(stream.device).mtl_device()->newSharedEvent(), dtor); - } else { - auto dtor = [](void* buf) { - allocator::free(static_cast(buf)); - }; - auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr(); - fence_ = std::shared_ptr(buf, dtor); - cpu_value()[0] = 0; + fence = static_cast(d->newSharedEvent()); + } else { + auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr(); + fence = static_cast(buf); + cpu_value()[0] = 0; + } } + + ~FenceImpl() { + if (!use_fast) { + // Wraps Metal SharedEvent + auto p = metal::new_scoped_memory_pool(); + static_cast(fence)->release(); + } else { + allocator::free(static_cast(fence)); + } + } + bool use_fast{false}; + uint32_t count{0}; + void* fence; + + std::atomic_uint* cpu_value() { + return static_cast( + static_cast(fence)->contents()); + } +}; + +Fence::Fence(Stream) { + auto dtor = [](void* ptr) { delete static_cast(ptr); }; + fence_ = std::shared_ptr(new FenceImpl{}, dtor); } -void Fence::wait_gpu(array& x) { - gpu_count_++; - auto& d = metal::device(stream_.device); - auto idx = stream_.index; +void Fence::wait(Stream stream, const array& x) { + auto& f = *static_cast(fence_.get()); - if (!use_fast_) { + if (stream.device == Device::cpu) { + scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { + auto& f = *static_cast(fence_.get()); + if (!f.use_fast) { + if (!static_cast(f.fence)->waitUntilSignaledValue( + count, -1)) { + throw std::runtime_error("[Fence::wait] Timed out"); + } + return; + } + while (f.cpu_value()[0] < count) { + } + }); + return; + } + + auto& d = metal::device(stream.device); + auto idx = stream.index; + + if (!f.use_fast) { d.end_encoding(idx); auto command_buffer = d.get_command_buffer(idx); - command_buffer->encodeWait( - static_cast(fence_.get()), gpu_count_); + command_buffer->encodeWait(static_cast(f.fence), f.count); command_buffer->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); return; @@ -51,7 +83,7 @@ void Fence::wait_gpu(array& x) { auto& compute_encoder = d.get_command_encoder(idx); - // Register the output to ensure that no kernels which depends on the + // Register outputs to ensure that no kernels which depends on the // output starts before this one is done compute_encoder.register_output_array(x); @@ -59,36 +91,49 @@ void Fence::wait_gpu(array& x) { MTL::Size kernel_dims = MTL::Size(1, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); - auto buf = static_cast(fence_.get()); + auto buf = static_cast(f.fence); 
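// Fast path: bind the shared fence buffer and the expected count to a
// single-thread kernel that blocks until the buffer's value reaches the
// count, keeping the wait inside the current command buffer.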
compute_encoder.set_buffer(buf, 0); - compute_encoder.set_bytes(gpu_count_, 1); + compute_encoder.set_bytes(f.count, 1); compute_encoder.dispatch_threads(kernel_dims, kernel_dims); d.get_command_buffer(idx)->addCompletedHandler( - [fence = fence_](MTL::CommandBuffer* cbuf) {}); + [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } -void Fence::update_gpu(const array& x) { - gpu_count_++; - auto& d = metal::device(stream_.device); - auto idx = stream_.index; +void Fence::update(Stream stream, const array& x) { + auto& f = *static_cast(fence_.get()); + f.count++; - if (!use_fast_) { + if (stream.device == Device::cpu) { + scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { + auto& f = *static_cast(fence_.get()); + if (!f.use_fast) { + static_cast(f.fence)->setSignaledValue(count); + return; + } + + f.cpu_value()[0] = count; + }); + return; + } + + auto& d = metal::device(stream.device); + auto idx = stream.index; + if (!f.use_fast) { d.end_encoding(idx); auto command_buffer = d.get_command_buffer(idx); command_buffer->encodeSignalEvent( - static_cast(fence_.get()), gpu_count_); + static_cast(f.fence), f.count); command_buffer->addCompletedHandler( [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); return; } - // Launch input visibility kernel + // Launch input visibility kernels auto& compute_encoder = d.get_command_encoder(idx); auto kernel = d.get_kernel("input_coherent"); uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t); - MTL::Size group_dims = MTL::Size(1024, 1, 1); MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); @@ -96,7 +141,7 @@ void Fence::update_gpu(const array& x) { compute_encoder.set_bytes(nthreads, 1); compute_encoder.dispatch_threadgroups(group_dims, grid_dims); - // Barrier on previous kernel + // Barrier on previous kernels compute_encoder.barrier(); // Launch value update kernel @@ -104,42 +149,13 @@ void Fence::update_gpu(const array& x) { MTL::Size kernel_dims = MTL::Size(1, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); - auto buf = static_cast(fence_.get()); + auto buf = static_cast(f.fence); compute_encoder.set_buffer(buf, 0); - compute_encoder.set_bytes(gpu_count_, 1); + compute_encoder.set_bytes(f.count, 1); compute_encoder.dispatch_threads(kernel_dims, kernel_dims); d.get_command_buffer(idx)->addCompletedHandler( - [fence = fence_](MTL::CommandBuffer* cbuf) {}); -} - -void Fence::wait() { - cpu_count_++; - if (!use_fast_) { - if (!static_cast(fence_.get()) - ->waitUntilSignaledValue(cpu_count_, -1)) { - throw std::runtime_error("[Event::wait] Timed out"); - } - return; - } - while (cpu_value()[0] < cpu_count_) { - } -} - -void Fence::update() { - cpu_count_++; - - if (!use_fast_) { - static_cast(fence_.get())->setSignaledValue(cpu_count_); - return; - } - - cpu_value()[0] = cpu_count_; -} - -std::atomic_uint* Fence::cpu_value() { - return static_cast( - static_cast(fence_.get())->contents()); + [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } } // namespace mlx::core diff --git a/mlx/backend/metal/fence.h b/mlx/backend/metal/fence.h deleted file mode 100644 index 28e04e643..000000000 --- a/mlx/backend/metal/fence.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include "mlx/backend/metal/device.h" - -namespace mlx::core { - -/* A fence to be used for synchronizing work between the CPU and GPU - * - * Calls to `update_gpu` should be paired with calls to `wait`. 
This ensures - * that the array passed to `update_gpu` is computed and visible to the CPU - * after the call to `wait` returns. - * - * Calls to `update` should be paired with calls to `wait_gpu`. This ensures - * that the array passed to `wait_gpu` will not be read by the GPU until the CPU - * has called `update`. - * - * The fence supports slow (default) and fast mode. Fast mode requires setting - * the environment variable `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires - * Metal 3.2+ (macOS 15+, iOS 18+). - */ -class Fence { - public: - Fence(const Stream& stream); - - void update_gpu(const array& x); - void wait_gpu(array& x); - - void wait(); - void update(); - - private: - Stream stream_; - std::shared_ptr fence_; - uint32_t cpu_count_{0}; - uint32_t gpu_count_{0}; - bool use_fast_; - std::atomic_uint* cpu_value(); -}; - -} // namespace mlx::core diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 7dcc761af..7b711e28b 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -82,7 +82,7 @@ void Hadamard::eval_gpu(const std::vector& inputs, array& out) { const array& in_contiguous = check_input(in); if (in_contiguous.is_donatable()) { - out.move_shared_buffer(in_contiguous); + out.copy_shared_buffer(in_contiguous); } else { out.set_data(allocator::malloc_or_wait(out.nbytes())); } diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 4933a0855..ba13b4b59 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -2,7 +2,6 @@ #include #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/event.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -23,88 +22,78 @@ inline void check_error(MTL::CommandBuffer* cbuf) { } } -std::function make_task(array arr, bool signal) { - auto task = [arr = std::move(arr), signal]() mutable { - auto pool = new_scoped_memory_pool(); - auto s = arr.primitive().stream(); - auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); +void eval(array& arr) { + auto pool = new_scoped_memory_pool(); + auto s = arr.primitive().stream(); + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); - for (auto& input : arr.inputs()) { - if (input.event().valid() && - input.event().stream() != arr.primitive().stream()) { - input.event().wait(); - } + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); } - auto outputs = arr.outputs(); - { - // If the array is a tracer hold a reference - // to its inputs so they don't get donated - std::vector inputs; - if (arr.is_tracer()) { - inputs = arr.inputs(); - } + debug_set_primitive_buffer_label(command_buffer, arr.primitive()); + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } - debug_set_primitive_buffer_label(command_buffer, arr.primitive()); - try { - arr.primitive().eval_gpu(arr.inputs(), outputs); - } catch (const std::exception& error) { - abort_with_exception(error); - } - } - std::vector> buffers; - for (auto& in 
: arr.inputs()) { - buffers.push_back(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.push_back(s.data_shared_ptr()); - } - if (!arr.is_tracer()) { - arr.detach(); - } - for (auto& out : outputs) { - out.set_status(array::Status::evaluated); - } - - if (signal || d.command_buffer_needs_commit(s.index)) { - if (signal) { - encode_signal(arr.event()); - } - d.end_encoding(s.index); - scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); - } else { - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); - }); - } - }; - return task; + if (d.command_buffer_needs_commit(s.index)) { + d.end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + scheduler::notify_task_completion(s); + check_error(cbuf); + }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + check_error(cbuf); + }); + } } -std::function make_synchronize_task( - Stream s, - std::shared_ptr> p) { - return [s, p = std::move(p)]() { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - cb->retain(); - d.end_encoding(s.index); - d.commit_command_buffer(s.index); - cb->waitUntilCompleted(); - check_error(cb); - cb->release(); - p->set_value(); - }; +void finalize(Stream s) { + auto pool = new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + d.end_encoding(s.index); + scheduler::notify_new_task(s); + cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { + scheduler::notify_task_completion(s); + check_error(cbuf); + }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); +} + +void synchronize(Stream s) { + auto pool = new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + cb->retain(); + d.end_encoding(s.index); + d.commit_command_buffer(s.index); + cb->waitUntilCompleted(); + check_error(cb); + cb->release(); } void start_capture(std::string path, id object) { diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/metal/metal_impl.h index 5c0a14357..9ca8d2f80 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/metal/metal_impl.h @@ -14,10 +14,8 @@ void new_stream(Stream stream); std::unique_ptr> new_scoped_memory_pool(); -std::function make_task(array arr, bool signal); - -std::function make_synchronize_task( - Stream s, - std::shared_ptr> p); +void eval(array& arr); +void finalize(Stream s); +void synchronize(Stream s); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index 3f5216392..b9a82b876 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -18,33 +18,33 @@ void RMSNorm::eval_gpu( auto& out = outputs[0]; // Make sure that the last dimension is contiguous - std::vector copies; - auto check_input = [&copies, &s](const array& x) -> const array& { + auto set_output = [&s, &out](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && 
x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc_or_wait(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } return x; } else { - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; } }; - const array& x = check_input(inputs[0]); - const array& w = inputs[1]; - if (x.is_donatable()) { - out.move_shared_buffer(x); - } else { - out.set_data( - allocator::malloc_or_wait(x.data_size() * x.itemsize()), - x.data_size(), - x.strides(), - x.flags()); - } + const array x = set_output(inputs[0]); + const array& w = inputs[1]; auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; @@ -88,8 +88,6 @@ void RMSNorm::eval_gpu( compute_encoder.set_bytes(w_stride, 5); compute_encoder.dispatch_threads(grid_dims, group_dims); } - - d.add_temporaries(std::move(copies), s.index); } void RMSNormVJP::eval_gpu( @@ -101,20 +99,22 @@ void RMSNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. - auto check_input = [&d, &s](const array& x) -> array { + auto check_input = [&d, &s](const array& x) -> std::pair { if (x.flags().row_contiguous) { - return x; + return {x, false}; } array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - d.add_temporary(x_copy, s.index); - - return x_copy; + return {x_copy, true}; }; - const array& x = check_input(inputs[0]); + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[2].is_donatable(); + auto [x, copied] = check_input(inputs[0]); + donate_x |= copied; const array& w = inputs[1]; - const array& g = check_input(inputs[2]); + auto [g, g_copied] = check_input(inputs[2]); + donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; @@ -122,17 +122,18 @@ void RMSNormVJP::eval_gpu( bool has_w = w.ndim() != 0; // Allocate space for the outputs - bool x_in_gx = false; bool g_in_gx = false; if (x.is_donatable()) { - gx.move_shared_buffer(x); - x_in_gx = true; + gx.copy_shared_buffer(x); } else if (g.is_donatable()) { - gx.move_shared_buffer(g); + gx.copy_shared_buffer(g); g_in_gx = true; } else { gx.set_data(allocator::malloc_or_wait(gx.nbytes())); } + if (g_copied && !g_in_gx) { + d.add_temporary(g, s.index); + } auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; @@ -141,11 +142,9 @@ void RMSNormVJP::eval_gpu( // gradients before they are accumulated. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; - bool g_in_gw = false; if (has_w) { - if (!g_in_gx && g.is_donatable()) { - gw_temp.move_shared_buffer(g); - g_in_gw = true; + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); d.add_temporary(gw_temp, s.index); @@ -189,9 +188,9 @@ void RMSNormVJP::eval_gpu( uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(x_in_gx ? 
gx : x, 0); + compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2); + compute_encoder.set_input_array(g, 2); compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_bytes(eps_, 5); @@ -216,35 +215,35 @@ void LayerNorm::eval_gpu( auto& out = outputs[0]; // Make sure that the last dimension is contiguous - std::vector copies; - auto check_input = [&copies, &s](const array& x) -> const array& { + auto set_output = [&s, &out](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc_or_wait(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } return x; } else { - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; } }; - const array& x = check_input(inputs[0]); + + const array x = set_output(inputs[0]); const array& w = inputs[1]; const array& b = inputs[2]; - if (x.is_donatable()) { - out.move_shared_buffer(x); - } else { - out.set_data( - allocator::malloc_or_wait(x.data_size() * x.itemsize()), - x.data_size(), - x.strides(), - x.flags()); - } - auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; @@ -290,8 +289,6 @@ void LayerNorm::eval_gpu( compute_encoder.set_bytes(b_stride, 7); compute_encoder.dispatch_threads(grid_dims, group_dims); } - - d.add_temporaries(std::move(copies), s.index); } void LayerNormVJP::eval_gpu( @@ -303,21 +300,22 @@ void LayerNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
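// The donation bookkeeping below mirrors RMSNormVJP above: an input counts
// as donatable either because the caller donated it (is_donatable) or
// because check_input had to materialize a row-contiguous copy
// (copied == true), in which case the copy is owned here and may be reused
// for gx or gw_temp.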
- auto check_input = [&d, &s](const array& x) -> array { + auto check_input = [&s](const array& x) -> std::pair { if (x.flags().row_contiguous) { - return x; + return {x, false}; } - array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - d.add_temporary(x_copy, s.index); - - return x_copy; + return {x_copy, true}; }; - const array& x = check_input(inputs[0]); + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[3].is_donatable(); + auto [x, copied] = check_input(inputs[0]); + donate_x |= copied; const array& w = inputs[1]; const array& b = inputs[2]; - const array& g = check_input(inputs[3]); + auto [g, g_copied] = check_input(inputs[3]); + donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; array& gb = outputs[2]; @@ -326,17 +324,18 @@ void LayerNormVJP::eval_gpu( bool has_w = w.ndim() != 0; // Allocate space for the outputs - bool x_in_gx = false; bool g_in_gx = false; - if (x.is_donatable()) { - gx.move_shared_buffer(x); - x_in_gx = true; - } else if (g.is_donatable()) { - gx.move_shared_buffer(g); + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); g_in_gx = true; } else { gx.set_data(allocator::malloc_or_wait(gx.nbytes())); } + if (g_copied && !g_in_gx) { + d.add_temporary(g, s.index); + } auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; @@ -345,15 +344,13 @@ void LayerNormVJP::eval_gpu( // gradient accumulators. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; - bool g_in_gw = false; if (has_w) { - if (!g_in_gx && g.is_donatable()) { - gw_temp.move_shared_buffer(g); - g_in_gw = true; + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + d.add_temporary(gw_temp, s.index); } - d.add_temporary(gw_temp, s.index); } gw.set_data(allocator::malloc_or_wait(gw.nbytes())); gb.set_data(allocator::malloc_or_wait(gb.nbytes())); @@ -364,14 +361,7 @@ void LayerNormVJP::eval_gpu( ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); strided_reduce_general_dispatch( - g_in_gx ? gx : (g_in_gw ? gw_temp : g), - gb, - "sum", - plan, - {0}, - compute_encoder, - d, - s); + g, gb, "sum", plan, {0}, compute_encoder, d, s); } const int simd_size = 32; @@ -409,9 +399,9 @@ void LayerNormVJP::eval_gpu( uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(x_in_gx ? gx : x, 0); + compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? 
gw_temp : g), 2); + compute_encoder.set_input_array(g, 2); compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_bytes(eps_, 5); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index df8638012..76f7e3946 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -5,12 +5,10 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/common/load.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/event.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" @@ -45,7 +43,8 @@ void reshape(const array& in, array& out, Stream s) { shared_buffer_reshape(in, out_strides, out); } } -array compute_dynamic_offset( + +static array compute_dynamic_offset( const array& indices, const Strides& strides, const std::vector& axes, @@ -57,7 +56,7 @@ array compute_dynamic_offset( bool donate = indices.is_donatable() && (indices.data_size() * indices.itemsize()) >= offset.itemsize(); if (donate) { - offset.move_shared_buffer(indices); + offset.copy_shared_buffer(indices); } else { offset.set_data(allocator::malloc_or_wait(offset.itemsize())); } @@ -88,7 +87,7 @@ array compute_dynamic_offset( auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(donate ? offset : indices, 0); + compute_encoder.set_input_array(indices, 0); compute_encoder.set_output_array(offset, 1); compute_encoder.set_vector_bytes(strides, 2); compute_encoder.set_vector_bytes(axes, 3); @@ -254,7 +253,7 @@ void Contiguous::eval_gpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; if (in.flags().row_contiguous || (allow_col_major_ && in.flags().col_contiguous)) { - move_or_copy(in, out); + out.copy_shared_buffer(in); } else { copy_gpu(in, out, CopyType::General); } @@ -302,31 +301,7 @@ void Unflatten::eval_gpu(const std::vector& inputs, array& out) { } void Load::eval_gpu(const std::vector& inputs, array& out) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto read_task = [out = unsafe_weak_copy(out), - offset = offset_, - reader = reader_, - swap_endianness = swap_endianness_]() mutable { - load(out, offset, reader, swap_endianness); - }; - - // Limit the size that the command buffer will wait on to avoid timing out - // on the event (<4 seconds). 
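// The IO/event plumbing below is deleted: Load loses its GPU path entirely
// (NO_CPU(Load) is also removed further down) and now throws, with file
// reads presumably handled on a CPU stream under the new encoder model.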
- if (out.nbytes() > (1 << 28)) { - read_task(); - return; - } - - auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); - - auto e = Event(stream()); - e.set_value(1); - encode_wait(e); - auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable { - fut.wait(); - e.signal(); - }; - scheduler::enqueue(io_stream(), std::move(signal_task)); + throw std::runtime_error("[Load::eval_gpu] Not implemented."); } void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { @@ -452,7 +427,7 @@ void DynamicSliceUpdate::eval_gpu( auto& start_indices = inputs[2]; if (upd.size() == 0) { - move_or_copy(in, out); + out.copy_shared_buffer(in); return; } @@ -491,7 +466,7 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { auto& upd = inputs[1]; if (upd.size() == 0) { - move_or_copy(in, out); + out.copy_shared_buffer(in); return; } @@ -575,8 +550,8 @@ void View::eval_gpu(const std::vector& inputs, array& out) { strides[i] *= ibytes; strides[i] /= obytes; } - move_or_copy( - in, out, strides, in.flags(), in.data_size() * ibytes / obytes); + out.copy_shared_buffer( + in, strides, in.flags(), in.data_size() * ibytes / obytes); } else { auto tmp = array(in.shape(), in.dtype(), nullptr, {}); tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); @@ -587,7 +562,7 @@ void View::eval_gpu(const std::vector& inputs, array& out) { flags.row_contiguous = true; auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; - out.move_shared_buffer(tmp, out.strides(), flags, out.size()); + out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); } } diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index c6da9278a..ce3384c54 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -41,7 +41,7 @@ void RoPE::eval_gpu( } else if (in.flags().row_contiguous) { if (in.is_donatable()) { donated = true; - out.move_shared_buffer(in); + out.copy_shared_buffer(in); } else { out.set_data(allocator::malloc_or_wait(out.nbytes())); } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index a349fd031..938d7afa6 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -152,7 +152,7 @@ void sdpa_vector( compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments - compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); + compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); compute_encoder.set_output_array(out, 3); @@ -242,7 +242,7 @@ void sdpa_vector_2pass( compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments - compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? 
out : q, 0); + compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); compute_encoder.set_output_array(intermediate, 3); @@ -357,7 +357,7 @@ void ScaledDotProductAttention::eval_gpu( // Donate the query if possible if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && q.size() == o.size()) { - o.move_shared_buffer(q); + o.copy_shared_buffer(q); } else { if (o.shape(2) == 1) { o.set_data(allocator::malloc_or_wait(o.nbytes())); diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index b31fdf4f9..dd123662f 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -21,7 +21,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { auto in = inputs[0]; if (in.flags().contiguous && in.strides()[axis_] != 0) { if (donate && in.itemsize() == out.itemsize()) { - out.move_shared_buffer(in); + out.copy_shared_buffer(in); } else { out.set_data( allocator::malloc_or_wait(in.data_size() * out.itemsize()), @@ -33,7 +33,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { array arr_copy(in.shape(), in.dtype(), nullptr, {}); copy_gpu(in, arr_copy, CopyType::General, s); in = std::move(arr_copy); - out.move_shared_buffer(in); + out.copy_shared_buffer(in); } bool contiguous = in.strides()[axis_] == 1; @@ -68,8 +68,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { if (contiguous) { auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array( - in.data_shared_ptr() == nullptr ? out : in, 0); + compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); size_t size = in.shape(axis_); compute_encoder.set_bytes(size, 2); diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 732c5d98c..854f31f4b 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -22,31 +22,32 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { auto& d = metal::device(s.device); // Make sure that the last dimension is contiguous - std::vector copies; - auto check_input = [&copies, &s](const array& x) -> const array& { + auto set_output = [&s, &out](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc_or_wait(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } return x; } else { - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; } }; - const array& in = check_input(inputs[0]); - if (in.is_donatable()) { - out.move_shared_buffer(in); - } else { - out.set_data( - allocator::malloc_or_wait(in.data_size() * in.itemsize()), - in.data_size(), - in.strides(), - in.flags()); - } + + const array in = set_output(inputs[0]); int axis_size = in.shape().back(); int n_rows = in.data_size() / axis_size; @@ -82,14 +83,11 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { } compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array( - in.data_shared_ptr() == nullptr ? 
out : in, 0); + compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); compute_encoder.set_bytes(axis_size, 2); compute_encoder.dispatch_threads(grid_dims, group_dims); } - - d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 6c55ca02d..36bfd3e2b 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -71,12 +71,9 @@ void ternary_op_gpu_inplace( auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - bool donate_a = a.data_shared_ptr() == nullptr; - bool donate_b = b.data_shared_ptr() == nullptr; - bool donate_c = c.data_shared_ptr() == nullptr; - compute_encoder.set_input_array(donate_a ? out : a, 0); - compute_encoder.set_input_array(donate_b ? out : b, 1); - compute_encoder.set_input_array(donate_c ? out : c, 2); + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(c, 2); compute_encoder.set_output_array(out, 3); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); @@ -129,7 +126,7 @@ void ternary_op_gpu( auto& b = inputs[1]; auto& c = inputs[2]; TernaryOpType topt = get_ternary_op_type(a, b, c); - set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */); + set_ternary_op_output_data(a, b, c, out, topt); ternary_op_gpu_inplace(inputs, out, op, s); } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 2652ecbfa..741c7d70a 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -57,8 +57,7 @@ void unary_op_gpu_inplace( auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array( - in.data_shared_ptr() == nullptr ? out : in, 0); + compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); if (!contig) { // Launch up to 3D grid of threads @@ -95,7 +94,7 @@ void unary_op_gpu( bool contig = in.flags().contiguous; if (contig) { if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.move_shared_buffer(in); + out.copy_shared_buffer(in); } else { out.set_data( allocator::malloc_or_wait(in.data_size() * out.itemsize()), @@ -169,7 +168,7 @@ void Round::eval_gpu(const std::vector& inputs, array& out) { unary_op_gpu(inputs, out, get_primitive_string(this)); } else { // No-op integer types - move_or_copy(in, out); + out.copy_shared_buffer(in); } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 09415e999..cc56bab32 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -69,19 +69,4 @@ void concatenate(std::string& acc, T first, Args... args) { concatenate(acc, args...); } -/** - * Get a new array that refers to the same data but has a non-owning pointer to - * them. 
- */ -inline array unsafe_weak_copy(const array& x) { - return array( - x.buffer(), - x.shape(), - x.dtype(), - x.strides(), - x.data_size(), - x.flags(), - [](auto b) {}); -} - } // namespace mlx::core diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index ac17d8059..e1524ec63 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -1,2 +1,6 @@ -target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 3733af28d..2f1ae566f 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include "mlx/primitives.h" +#include "mlx/distributed/primitives.h" #include "mlx/fast_primitives.h" #define NO_CPU_MULTI(func) \ @@ -75,7 +76,6 @@ NO_CPU(Hadamard) NO_CPU(Imag) NO_CPU(Less) NO_CPU(LessEqual) -NO_CPU(Load) NO_CPU(Log) NO_CPU(Log1p) NO_CPU(LogicalNot) @@ -129,4 +129,11 @@ namespace fast { NO_CPU_MULTI(AffineQuantize) } // namespace fast +namespace distributed { +NO_CPU_MULTI(AllReduce) +NO_CPU_MULTI(AllGather) +NO_CPU_MULTI(Send) +NO_CPU_MULTI(Recv) +} // namespace distributed + } // namespace mlx::core diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_metal/CMakeLists.txt index f31619e69..962ceecb7 100644 --- a/mlx/backend/no_metal/CMakeLists.txt +++ b/mlx/backend/no_metal/CMakeLists.txt @@ -2,5 +2,6 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_metal/event.cpp index a41b4bb6e..692be20d2 100644 --- a/mlx/backend/no_metal/event.cpp +++ b/mlx/backend/no_metal/event.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. 
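// Without Metal, Event is a pure CPU primitive: EventCounter pairs a value
// with a mutex/condition variable, wait() blocks on the cv and signal()
// bumps the value and notifies. The new stream overloads just enqueue those
// calls on the scheduler. Sketch (stream names are illustrative):
//
//   Event e(s1);
//   e.set_value(1);
//   e.signal(s1); // runs e.signal() as a task on s1
//   e.wait(s2);   // a task on s2 blocks until the counter reaches 1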
#include "mlx/event.h" +#include "mlx/scheduler.h" #include #include @@ -13,13 +14,13 @@ struct EventCounter { std::condition_variable cv; }; -Event::Event(const Stream& stream) : stream_(stream) { +Event::Event(Stream stream) : stream_(stream) { auto dtor = [](void* ptr) { delete static_cast(ptr); }; event_ = std::shared_ptr(new EventCounter{}, dtor); } void Event::wait() { - auto ec = static_cast(raw_event().get()); + auto ec = static_cast(event_.get()); std::unique_lock lk(ec->mtx); if (ec->value >= value()) { return; @@ -28,7 +29,7 @@ void Event::wait() { } void Event::signal() { - auto ec = static_cast(raw_event().get()); + auto ec = static_cast(event_.get()); { std::lock_guard lk(ec->mtx); ec->value = value(); @@ -36,11 +37,19 @@ void Event::signal() { ec->cv.notify_all(); } +void Event::wait(Stream stream) { + scheduler::enqueue(stream, [*this]() mutable { wait(); }); +} + +void Event::signal(Stream stream) { + scheduler::enqueue(stream, [*this]() mutable { signal(); }); +} + bool Event::is_signaled() const { - auto ec = static_cast(raw_event().get()); + auto ec = static_cast(event_.get()); { std::lock_guard lk(ec->mtx); - return (ec->value > value()); + return (ec->value >= value()); } } } // namespace mlx::core diff --git a/mlx/backend/no_metal/fence.cpp b/mlx/backend/no_metal/fence.cpp new file mode 100644 index 000000000..af22108a8 --- /dev/null +++ b/mlx/backend/no_metal/fence.cpp @@ -0,0 +1,54 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +#include "mlx/fence.h" +#include "mlx/scheduler.h" + +namespace mlx::core { + +struct FenceImpl { + uint32_t count{0}; + uint32_t value{0}; + std::mutex mtx; + std::condition_variable cv; +}; + +Fence::Fence(Stream) { + auto dtor = [](void* ptr) { delete static_cast(ptr); }; + fence_ = std::shared_ptr(new FenceImpl{}, dtor); +} + +void Fence::wait(Stream stream, const array&) { + auto& f = *static_cast(fence_.get()); + if (stream.device == Device::cpu) { + scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { + auto& f = *static_cast(fence_.get()); + std::unique_lock lk(f.mtx); + if (f.value >= count) { + return; + } + f.cv.wait(lk, [&f, count] { return f.value >= count; }); + }); + } else { + throw std::runtime_error("[Fence::wait] Invalid stream."); + } +} + +void Fence::update(Stream stream, const array&) { + auto& f = *static_cast(fence_.get()); + f.count++; + if (stream.device == Device::cpu) { + scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { + auto& f = *static_cast(fence_.get()); + std::unique_lock lk(f.mtx); + f.value = count; + f.cv.notify_all(); + }); + } else { + throw std::runtime_error("[Fence::update] Invalid stream."); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 9ae9800a2..359e63b5d 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -16,17 +16,19 @@ std::unique_ptr> new_scoped_memory_pool() { return nullptr; } -std::function make_task(array, bool) { +void eval(array&) { throw std::runtime_error( - "[metal::make_task] Cannot make GPU task without metal backend"); + "[metal::eval] Cannot eval on GPU without metal backend"); } -std::function make_synchronize_task( - Stream, - std::shared_ptr>) { +void finalize(Stream) { throw std::runtime_error( - "[metal::make_synchronize_task] Cannot synchronize GPU" - " without metal backend"); + "[metal::finalize] Cannot finalize GPU without metal backend"); +} + +void synchronize(Stream) { + throw std::runtime_error( + 
"[metal::synchronize] Cannot synchronize GPU without metal backend"); } // No-ops when Metal is not available. diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index cb6edd1e8..33f1d1320 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -6,31 +6,25 @@ #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/ring/ring.h" -#include "mlx/scheduler.h" namespace mlx::core::distributed { namespace detail { -Stream communication_stream() { - static Stream comm_stream = new_stream(Device::cpu); - return comm_stream; +void all_sum(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_sum(input, output, stream); } -void all_sum(Group group, const array& input, array& output) { - group.raw_group()->all_sum(input, output); +void all_gather(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_gather(input, output, stream); } -void all_gather(Group group, const array& input, array& output) { - group.raw_group()->all_gather(input, output); +void send(Group group, const array& input, int dst, Stream stream) { + group.raw_group()->send(input, dst, stream); } -void send(Group group, const array& input, int dst) { - group.raw_group()->send(input, dst); -} - -void recv(Group group, array& out, int src) { - group.raw_group()->recv(out, src); +void recv(Group group, array& out, int src, Stream stream) { + group.raw_group()->recv(out, src, stream); } class EmptyGroup : public GroupImpl { @@ -47,19 +41,19 @@ class EmptyGroup : public GroupImpl { throw std::runtime_error("Cannot split the distributed group further."); } - void all_sum(const array& input, array& output) override { + void all_sum(const array&, array&, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } - void all_gather(const array& input, array& output) override { + void all_gather(const array&, array&, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } - void send(const array& input, int dst) override { + void send(const array&, int, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } - void recv(array& out, int src) override { + void recv(array&, int, Stream) override { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } @@ -122,10 +116,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { backends.insert({"any", group}); } backends.insert({std::move(bk_), group}); - - // Ensure the communication stream is alive before - // the graph is evaluated - detail::communication_stream(); return Group(group); } diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index bf5fc13b5..7c06068b2 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -17,25 +17,22 @@ class GroupImpl { virtual int size() = 0; virtual std::shared_ptr split(int color, int key = -1) = 0; - virtual void all_sum(const array& input, array& output) = 0; - virtual void all_gather(const array& input, array& output) = 0; - virtual void send(const array& input, int dst) = 0; - virtual void recv(array& out, int src) = 0; + virtual void all_sum(const array& input, array& output, Stream stream) = 0; + virtual void all_gather(const array& input, array& output, Stream stream) = 0; + virtual void 
+  virtual void send(const array& input, int dst, Stream stream) = 0;
+  virtual void recv(array& out, int src, Stream stream) = 0;
 };
 
-/* Return the communication stream. */
-Stream communication_stream();
-
 /* Perform an all reduce sum operation */
-void all_sum(Group group, const array& input, array& output);
+void all_sum(Group group, const array& input, array& output, Stream stream);
 
 /* Perform an all gather operation */
-void all_gather(Group group, const array& input, array& output);
+void all_gather(Group group, const array& input, array& output, Stream stream);
 
 /** Send an array to the dst rank */
-void send(Group group, const array& input, int dst);
+void send(Group group, const array& input, int dst, Stream stream);
 
 /** Recv an array from the src rank */
-void recv(Group group, array& out, int src);
+void recv(Group group, array& out, int src, Stream stream);
 
 } // namespace mlx::core::distributed::detail
diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
index e532cd771..b9136f701 100644
--- a/mlx/distributed/mpi/mpi.cpp
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -3,11 +3,10 @@
 #include <dlfcn.h>
 #include <mpi.h>
 
-#include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
-#include "mlx/scheduler.h"
 
 #define LOAD_SYMBOL(symbol, variable) \
   {                                   \
@@ -25,16 +24,6 @@ using GroupImpl = mlx::core::distributed::detail::GroupImpl;
 
 namespace {
 
-array ensure_row_contiguous(const array& arr) {
-  if (arr.flags().row_contiguous) {
-    return arr;
-  } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy(arr, arr_copy, CopyType::General);
-    return arr_copy;
-  }
-}
-
 template <typename T>
 void simple_sum(
     void* input,
@@ -281,9 +270,12 @@ class MPIGroup : public GroupImpl {
     return std::make_shared<MPIGroup>(new_comm, false);
   }
 
-  void all_sum(const array& input_, array& output) override {
-    array input = ensure_row_contiguous(input_);
-    mpi().all_reduce(
+  void all_sum(const array& input, array& output, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.set_output_array(output);
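+    // Hand the collective to the stream's CPU command encoder so it runs in
+    // order with the rest of the work queued on that stream.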
+    encoder.dispatch(
+        mpi().all_reduce,
+        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
+                                                    : input.data<void>(),
+        output.data<void>(),
@@ -293,9 +285,12 @@ class MPIGroup : public GroupImpl {
         comm_);
   }
 
-  void all_gather(const array& input_, array& output) override {
-    array input = ensure_row_contiguous(input_);
-    mpi().all_gather(
+  void all_gather(const array& input, array& output, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.set_output_array(output);
+    encoder.dispatch(
+        mpi().all_gather,
         input.data<void>(),
         input.size(),
         mpi().datatype(input),
@@ -305,22 +300,30 @@
         comm_);
   }
 
-  void send(const array& input_, int dst) override {
-    array input = ensure_row_contiguous(input_);
-    mpi().send(
-        input.data<void>(), input.size(), mpi().datatype(input), dst, 0, comm_);
+  void send(const array& input, int dst, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.dispatch(
+        mpi().send,
+        input.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        dst,
+        0,
+        comm_);
   }
 
-  void recv(array& out, int src) override {
-    MPI_Status status;
-    mpi().recv(
-        out.data<void>(),
-        out.size(),
-        mpi().datatype(out),
-        src,
-        MPI_ANY_TAG,
-        comm_,
-        &status);
+  void recv(array& out, int src, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_output_array(out);
+    encoder.dispatch([out_ptr = out.data<void>(),
+                      out_size = out.size(),
+                      out_type = mpi().datatype(out),
+                      src,
+                      comm = comm_]() {
+      MPI_Status status;
+      mpi().recv(out_ptr, out_size, out_type, src, MPI_ANY_TAG, comm, &status);
+    });
   }
 
 private:
diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 2af552664..865911ac6 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -28,11 +28,11 @@ array all_sum(
   if (group.size() == 1) {
     return x;
   }
-
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
+      std::make_shared<AllReduce>(
+          to_stream(s, Device::cpu), group, AllReduce::Sum),
       {x});
 }
@@ -55,7 +55,7 @@
   return array(
       std::move(result_shape),
       x.dtype(),
-      std::make_shared<AllGather>(to_stream(s), group),
+      std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
       {x});
 }
@@ -80,7 +80,7 @@
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<Send>(to_stream(s), group, dst),
+      std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
       {x});
 }
@@ -105,7 +105,7 @@
   return array(
       std::move(shape),
       std::move(dtype),
-      std::make_shared<Recv>(to_stream(s), group, src),
+      std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
       std::vector<array>{});
 }
diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index 2a151cc0c..de28788b3 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -3,34 +3,12 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
-#include "mlx/backend/common/utils.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/ops.h"
 
 namespace mlx::core::distributed {
 
-void AllReduce::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
-
-  if (inputs[0].is_donatable()) {
-    outputs[0].copy_shared_buffer(inputs[0]);
-  } else {
-    outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
-  }
-
-  switch (reduce_type_) {
-    case Sum:
-      distributed::detail::all_sum(group(), inputs[0], outputs[0]);
-      break;
-    default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
-  }
-}
-
 std::pair<std::vector<array>, std::vector<array>> AllReduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -62,17 +40,6 @@ std::vector<array> AllReduce::vjp(
   return cotangents;
 }
 
-void AllGather::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
-
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
-
-  distributed::detail::all_gather(group(), inputs[0], outputs[0]);
-}
-
 std::pair<std::vector<array>, std::vector<array>> AllGather::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -99,30 +66,10 @@ std::vector<array> AllGather::vjp(
   return {slice(cotangents[0], starts, stops)};
 }
 
-void Send::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
-
-  distributed::detail::send(group(), inputs[0], dst_);
-  move_or_copy(inputs[0], outputs[0]);
-}
-
 std::pair<std::vector<array>, std::vector<array>> Send::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   return {{send(inputs[0], dst_, group(), stream())}, axes};
 }
 
-void Recv::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() == 0);
-  assert(outputs.size() == 1);
-
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
-  distributed::detail::recv(group(), outputs[0], src_);
-}
-
 } // namespace mlx::core::distributed
diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp
index 5bf08200e..d8fe53051 100644
--- a/mlx/distributed/ring/ring.cpp
+++ b/mlx/distributed/ring/ring.cpp
@@ -18,7 +18,7 @@
 
 #include
 
-#include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/threadpool.h"
@@ -140,8 +140,8 @@
   }
 
   template <typename T>
-  std::future<void> send(T* buffer, size_t size) {
-    return send_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
+  std::future<void> send(const T* buffer, size_t size) {
+    return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
   }
 
   template <typename T>
@@ -160,7 +160,7 @@
     std::promise<void> promise;
   };
 
-  std::future<void> send_impl(char* buffer, size_t size) {
+  std::future<void> send_impl(const char* buffer, size_t size) {
     std::promise<void> send_completed_promise;
     auto send_completed_future = send_completed_promise.get_future();
     if (size == 0) {
@@ -170,8 +170,8 @@
 
     {
       std::unique_lock<std::mutex> lock(queue_mutex_);
-      sends_.emplace_back(
-          SocketTask(buffer, size, std::move(send_completed_promise)));
+      sends_.emplace_back(SocketTask(
+          const_cast<char*>(buffer), size, std::move(send_completed_promise)));
     }
     condition_.notify_one();
     return send_completed_future;
@@ -503,16 +503,6 @@ std::vector<int> make_connections(
   return sockets;
 }
 
-array ensure_row_contiguous(const array& arr) {
-  if (arr.flags().row_contiguous) {
-    return arr;
-  } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy(arr, arr_copy, CopyType::General);
-    return arr_copy;
-  }
-}
-
 template <typename T>
 void sum_inplace(const T* input, T* output, size_t N) {
   while (N-- > 0) {
@@ -613,117 +603,131 @@ class RingGroup : public GroupImpl {
     return size_;
   }
 
-  void all_sum(const array& input_, array& output) override {
-    SWITCH_TYPE(output, all_sum<T>(input_, output));
+  void all_sum(const array& input_, array& output, Stream stream) override {
+    SWITCH_TYPE(output, all_sum<T>(input_, output, stream));
   }
 
   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
     throw std::runtime_error("[ring] Group split not supported.");
   }
 
-  void all_gather(const array& input, array& output) override {
+  void all_gather(const array& input, array& output, Stream stream) override {
     throw std::runtime_error("[ring] All gather not supported.");
   }
 
-  void send(const array& input_, int dst) override {
-    // Make sure that the input is row contiguous
-    array input = ensure_row_contiguous(input_);
-
-    int right = (rank_ + 1) % size_;
-    int left = (rank_ + size_ - 1) % size_;
-    if (dst == right) {
-      send(sockets_right_, input.data<char>(), input.nbytes());
-    } else if (dst == left) {
-      send(sockets_left_, input.data<char>(), input.nbytes());
-    } else {
-      std::ostringstream msg;
-      msg << "[ring] Send only supported to direct neighbors "
-          << "but tried to send to " << dst << " from " << rank_ << std::endl;
-      throw std::runtime_error(msg.str());
-    }
+  void send(const array& input, int dst, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.dispatch(
+        [input_ptr = input.data<char>(), nbytes = input.nbytes(), dst, this]() {
+          int right = (rank_ + 1) % size_;
+          int left = (rank_ + size_ - 1) % size_;
+          if (dst == right) {
+            send(sockets_right_, input_ptr, nbytes);
+          } else if (dst == left) {
+            send(sockets_left_, input_ptr, nbytes);
+          } else {
+            std::ostringstream msg;
+            msg << "[ring] Send only supported to direct neighbors "
+                << "but tried to send to " << dst << " from " << rank_
+                << std::endl;
+            throw std::runtime_error(msg.str());
+          }
+        });
   }
 
-  void recv(array& out, int src) override {
-    // NOTE: We 'll check the sockets with the opposite order of send so that
-    //       they work even with 2 nodes where left and right is the same
-    //       neighbor.
-    int right = (rank_ + 1) % size_;
-    int left = (rank_ + size_ - 1) % size_;
-    if (src == left) {
-      recv(sockets_left_, out.data<char>(), out.nbytes());
-    } else if (src == right) {
-      recv(sockets_right_, out.data<char>(), out.nbytes());
-    } else {
-      std::ostringstream msg;
-      msg << "[ring] Recv only supported from direct neighbors "
-          << "but tried to recv from " << src << " to " << rank_ << std::endl;
-      throw std::runtime_error(msg.str());
-    }
+  void recv(array& out, int src, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_output_array(out);
+    encoder.dispatch(
+        [out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
+          // NOTE: We'll check the sockets in the opposite order of send so
+          // that they work even with two nodes, where left and right are the
+          // same neighbor.
+          int right = (rank_ + 1) % size_;
+          int left = (rank_ + size_ - 1) % size_;
+          if (src == left) {
+            recv(sockets_left_, out_ptr, nbytes);
+          } else if (src == right) {
+            recv(sockets_right_, out_ptr, nbytes);
+          } else {
+            std::ostringstream msg;
+            msg << "[ring] Recv only supported from direct neighbors "
+                << "but tried to recv from " << src << " to " << rank_
+                << std::endl;
+            throw std::runtime_error(msg.str());
+          }
+        });
   }
 
 private:
  template <typename T>
-  void all_sum(const array& input_, array& output) {
-    // Make sure that the input is row contiguous
-    array input = ensure_row_contiguous(input_);
+  void all_sum(const array& input, array& output, Stream stream) {
+    auto in_ptr = input.data<char>();
+    auto out_ptr = output.data<char>();
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_output_array(output);
+    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() {
+      // If the input data cannot be split into size_ segments then copy it
+      // and all reduce a local buffer prefilled with 0s.
+      size_t nbytes = size * sizeof(T);
+      if (size < size_) {
+        // TODO: Maybe allocate dynamically so we don't have the constraint
+        // below?
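+        // The small-input path below stages the reduction through a fixed
+        // 1024-byte stack buffer zero-padded to size_ elements, hence the
+        // size check that follows.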
+        if (sizeof(T) * size_ > 1024) {
+          std::ostringstream msg;
+          msg << "Can't perform the ring all reduce of " << size
+              << " elements with a ring of size " << size_;
+          throw std::runtime_error(msg.str());
+        }
 
-    // If the input data cannot be split into size_ segments then copy it and
-    // all reduce a local buffer prefilled with 0s.
-    if (input.size() < size_) {
-      // TODO: Maybe allocate dynamically so we don't have the constraint
-      // below?
-      if (input.itemsize() * size_ > 1024) {
-        std::ostringstream msg;
-        msg << "Can't perform the ring all reduce of " << output.size()
-            << " elements with a ring of size " << size_;
-        throw std::runtime_error(msg.str());
+        char buffer[1024];
+        std::memset(buffer, 0, size_ * sizeof(T));
+        std::memcpy(buffer, in_ptr, nbytes);
+        all_sum_impl<T>(
+            reinterpret_cast<T*>(buffers_.data()),
+            reinterpret_cast<T*>(buffer),
+            size_,
+            sockets_right_[0],
+            sockets_left_[0],
+            -1);
+        std::memcpy(out_ptr, buffer, nbytes);
+        return;
       }
 
-      char buffer[1024];
-      std::memset(buffer, 0, size_ * input.itemsize());
-      std::memcpy(buffer, input.data<T>(), input.nbytes());
-      all_sum_impl<T>(
-          reinterpret_cast<T*>(buffers_.data()),
-          reinterpret_cast<T*>(buffer),
-          size_,
-          sockets_right_[0],
-          sockets_left_[0],
-          -1);
-      std::memcpy(output.data<T>(), buffer, output.nbytes());
-      return;
-    }
+      // If not inplace all reduce then copy the input to the output first
+      if (in_ptr != out_ptr) {
+        std::memcpy(out_ptr, in_ptr, nbytes);
+      }
 
-    // If not inplace all reduce then copy the input to the output first
-    if (input.data<T>() != output.data<T>()) {
-      std::memcpy(output.data<T>(), input.data<T>(), input.nbytes());
-    }
+      // Split the all reduces so that each member has at least 1 buffer to
+      // send/recv per segment.
+      constexpr size_t min_send_size = 262144;
+      size_t n_reduces = std::max(
+          std::min(
+              sockets_right_.size() + sockets_left_.size(),
+              nbytes / (size_ * min_send_size)),
+          1UL);
+      size_t step = ceildiv(size, n_reduces);
+      std::vector<std::future<void>> all_sums;
 
-    // Split the all reduces so that each member has at least 1 buffer to
-    // send/recv per segment.
-    constexpr size_t min_send_size = 262144;
-    size_t n_reduces = std::max(
-        std::min(
-            sockets_right_.size() + sockets_left_.size(),
-            output.nbytes() / (size_ * min_send_size)),
-        1UL);
-    size_t step = ceildiv(output.size(), n_reduces);
-    std::vector<std::future<void>> all_sums;
-
-    for (int i = 0; i < n_reduces; i++) {
-      all_sums.emplace_back(pool_.enqueue(std::bind(
-          &RingGroup::all_sum_impl<T>,
-          this,
-          reinterpret_cast<T*>(
-              buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
-          output.data<T>() + i * step,
-          std::min(output.size(), (i + 1) * step) - i * step,
-          sockets_right_[i / 2],
-          sockets_left_[i / 2],
-          (i % 2) ? -1 : 1)));
-    }
-    for (auto& f : all_sums) {
-      f.wait();
-    }
+      for (int i = 0; i < n_reduces; i++) {
+        all_sums.emplace_back(pool_.enqueue(std::bind(
+            &RingGroup::all_sum_impl<T>,
+            this,
+            reinterpret_cast<T*>(
+                buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
+            reinterpret_cast<T*>(out_ptr) + i * step,
+            std::min(size, (i + 1) * step) - i * step,
+            sockets_right_[i / 2],
+            sockets_left_[i / 2],
+            (i % 2) ? -1 : 1)));
+      }
+      for (auto& f : all_sums) {
+        f.wait();
+      }
+    });
   }
 
   template <typename T>
@@ -823,7 +827,8 @@
     recvs[b].wait();
   }
 
-  void send(const std::vector<int>& sockets, char* data, size_t data_size) {
+  void
+  send(const std::vector<int>& sockets, const char* data, size_t data_size) {
     size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
     std::vector<std::future<void>> sends;
     for (int i = 0; i < sockets.size(); i++) {
diff --git a/mlx/event.h b/mlx/event.h
index e492370a2..7054b60c7 100644
--- a/mlx/event.h
+++ b/mlx/event.h
@@ -10,9 +10,8 @@ namespace mlx::core {
 
 class Event {
  public:
-  Event() = default;
-
-  Event(const Stream& steam);
+  Event() {};
+  explicit Event(Stream stream);
 
   // Wait for the event to be signaled at its current value
   void wait();
@@ -20,6 +19,12 @@ class Event {
   // Signal the event at its current value
   void signal();
 
+  // Wait in the given stream for the event to be signaled at its current value
+  void wait(Stream stream);
+
+  // Signal the event at its current value in the given stream
+  void signal(Stream stream);
+
   // Check if the event has been signaled at its current value
   bool is_signaled() const;
 
@@ -44,15 +49,11 @@ class Event {
     return stream_;
   }
 
-  const std::shared_ptr<void>& raw_event() const {
-    return event_;
-  }
-
 private:
  // Default constructed stream should never be used
  // since the event is not yet valid
  Stream stream_{0, Device::cpu};
-  std::shared_ptr<void> event_;
+  std::shared_ptr<void> event_{nullptr};
  uint64_t value_{0};
};
diff --git a/mlx/export.cpp b/mlx/export.cpp
index 361d73f66..4eb3ff99a 100644
--- a/mlx/export.cpp
+++ b/mlx/export.cpp
@@ -832,7 +832,7 @@ ImportedFunction::ImportedFunction(const std::string& file)
         std::move(shape),
         type,
         std::make_shared<Load>(
-            default_stream(default_device()), is_ptr, offset),
+            default_stream(Device::cpu), is_ptr, offset),
         {}));
     is.seek(offset + tape.back().nbytes());
     constants.insert({id, tape.back()});
diff --git a/mlx/fence.h b/mlx/fence.h
new file mode 100644
index 000000000..66c2b5f49
--- /dev/null
+++ b/mlx/fence.h
@@ -0,0 +1,39 @@
+// Copyright © 2024 Apple Inc.
+
+#include <memory>
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+/* A fence to be used for synchronizing work between streams.
+ *
+ * Calls to `wait` wait in the given stream until all previous calls to
+ * `update` are complete on their given stream.
+ *
+ * The array passed to `update` is computed and visible after the call to
+ * `wait` returns. The array passed to `wait` will not be read until all
+ * previous calls to `update` have completed.
+ *
+ * Note, calls to `update` should always be made from the same thread or be
+ * explicitly synchronized so that they occur in sequence. Calls to `wait`
+ * can be on any thread.
+ *
+ * For the Metal back-end the fence supports slow (default) and fast mode.
+ * Fast mode requires setting the environment variable
+ * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+,
+ * iOS 18+).
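+ *
+ * An illustrative sketch of the intended use (`s1`, `s2`, and `x` here are
+ * placeholders for two streams and an array, not names from this header):
+ *
+ *   Fence f(s1);
+ *   f.update(s1, x); // publish x's result from s1
+ *   f.wait(s2, x);   // s2 will not read x before the update completes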
+ */
+class Fence {
+ public:
+  Fence() {};
+  explicit Fence(Stream stream);
+
+  void update(Stream stream, const array& x);
+  void wait(Stream stream, const array& x);
+
+ private:
+  std::shared_ptr<void> fence_{nullptr};
+};
+
+} // namespace mlx::core
diff --git a/mlx/io/CMakeLists.txt b/mlx/io/CMakeLists.txt
index 2402ff588..f9ab1c44e 100644
--- a/mlx/io/CMakeLists.txt
+++ b/mlx/io/CMakeLists.txt
@@ -1,13 +1,6 @@
 target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)
 
 if(MLX_BUILD_SAFETENSORS)
-  message(STATUS "Downloading json")
-  FetchContent_Declare(
-    json
-    URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
-  FetchContent_MakeAvailable(json)
-  target_include_directories(
-    mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)
 else()
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index 51c5805d9..59d91c007 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -226,6 +226,11 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
     throw std::runtime_error("[load] Failed to open " + in_stream->label());
   }
 
+  auto stream = to_stream(s, Device::cpu);
+  if (stream.device != Device::cpu) {
+    throw std::runtime_error("[load] Must run on a CPU stream.");
+  }
+
   ////////////////////////////////////////////////////////
   // Read header and prepare array details
 
@@ -309,7 +314,7 @@
   auto loaded_array = array(
       shape,
       dtype,
-      std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
+      std::make_shared<Load>(stream, in_stream, offset, swap_endianness),
       std::vector<array>{});
   if (col_contiguous) {
     loaded_array = transpose(loaded_array, s);
diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index f7c910372..7ee974186 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -200,6 +200,11 @@ SafetensorsLoad load_safetensors(
         "[load_safetensors] Failed to open " + in_stream->label());
   }
 
+  auto stream = to_stream(s, Device::cpu);
+  if (stream.device != Device::cpu) {
+    throw std::runtime_error("[load_safetensors] Must run on a CPU stream.");
+  }
+
   uint64_t jsonHeaderLength = 0;
   // This is the same limit as in the original Rust Safetensors code.
   constexpr uint64_t kMaxJsonHeaderLength = 100000000;
@@ -236,7 +241,7 @@
         shape,
         type,
         std::make_shared<Load>(
-            to_stream(s), in_stream, offset + data_offsets.at(0), false),
+            stream, in_stream, offset + data_offsets.at(0), false),
         std::vector<array>{});
     if (dtype == ST_F8_E4M3) {
       loaded_array = f8_e4m3_to_float(loaded_array, bfloat16, s);
diff --git a/mlx/primitives.h b/mlx/primitives.h
index c2c0576aa..bb0ca8080 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1231,11 +1231,7 @@ class Load : public UnaryPrimitive {
       : UnaryPrimitive(stream),
         reader_(std::move(reader)),
         offset_(offset),
-        swap_endianness_(swap_endianness) {
-    if (stream.device == Device::gpu) {
-      io_stream();
-    }
-  }
+        swap_endianness_(swap_endianness) {}
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1243,10 +1239,6 @@
   DEFINE_PRINT(Load)
 
 private:
-  Stream& io_stream() {
-    static Stream io_stream = new_stream(Device::cpu);
-    return io_stream;
-  };
  std::shared_ptr<io::Reader> reader_;
  size_t offset_;
  bool swap_endianness_;
diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp
index 29cd3db71..e3612f351 100644
--- a/mlx/scheduler.cpp
+++ b/mlx/scheduler.cpp
@@ -38,14 +38,14 @@ Stream new_stream() {
 }
 
 void synchronize(Stream s) {
-  auto p = std::make_shared<std::promise<void>>();
-  std::future<void> f = p->get_future();
   if (s.device == mlx::core::Device::cpu) {
+    auto p = std::make_shared<std::promise<void>>();
+    std::future<void> f = p->get_future();
     scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
+    f.wait();
   } else {
-    scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
+    metal::synchronize(s);
   }
-  f.wait();
 }
 
 void synchronize() {
diff --git a/mlx/scheduler.h b/mlx/scheduler.h
index dcf874783..bf34b38c0 100644
--- a/mlx/scheduler.h
+++ b/mlx/scheduler.h
@@ -20,16 +20,11 @@ struct StreamThread {
   std::queue<std::function<void()>> q;
   std::condition_variable cond;
   bool stop;
-  Stream stream;
   std::thread thread;
 
-  StreamThread(Stream stream)
-      : stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {
-    metal::new_stream(stream);
-  }
+  StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {}
 
   ~StreamThread() {
-    synchronize(stream);
     {
       std::lock_guard<std::mutex> lk(mtx);
       stop = true;
@@ -85,9 +80,14 @@ class Scheduler {
   Scheduler& operator=(Scheduler&&) = delete;
 
   Stream new_stream(const Device& d) {
-    auto stream = Stream(streams_.size(), d);
-    streams_.push_back(new StreamThread{stream});
-    return stream;
+    streams_.emplace_back(streams_.size(), d);
+    if (d == Device::gpu) {
+      threads_.push_back(nullptr);
+      metal::new_stream(streams_.back());
+    } else {
+      threads_.push_back(new StreamThread{});
+    }
+    return streams_.back();
   }
 
   template <typename F>
@@ -97,7 +97,7 @@ class Scheduler {
     return default_streams_.at(d.type);
   }
   Stream get_stream(int index) const {
-    return streams_.at(index)->stream;
+    return streams_.at(index);
   }
 
   void set_default_stream(const Stream& s) {
@@ -136,13 +136,19 @@
 
   ~Scheduler() {
     for (auto s : streams_) {
-      delete s;
+      synchronize(s);
+    }
+    for (auto t : threads_) {
+      if (t != nullptr) {
+        delete t;
+      }
     }
   }
 
 private:
  int n_active_tasks_;
-  std::vector<StreamThread*> streams_;
+  std::vector<StreamThread*> threads_;
+  std::vector<Stream> streams_;
  std::unordered_map<Device::DeviceType, Stream> default_streams_;
  std::condition_variable completion_cv;
  std::mutex mtx;
};
@@ -150,7 +156,7 @@
 
 template <typename F>
 void Scheduler::enqueue(const Stream& stream, F&& f) {
-  streams_[stream.index]->enqueue(std::forward<F>(f));
+  threads_[stream.index]->enqueue(std::forward<F>(f));
 }
 
 Scheduler& scheduler();
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index d91fee91b..f01082418 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -9,7 +9,9 @@
 #include
 #include
 
+#include "mlx/backend/cpu/eval.h"
 #include "mlx/backend/metal/metal_impl.h"
+#include "mlx/fence.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/scheduler.h"
@@ -19,6 +21,8 @@
 
 namespace mlx::core {
 
+static constexpr int MAX_ACTIVE_TASKS = 100;
+
 /* This class is only meant to be used in eval
  * for synchronizing with the main thread. */
 class Synchronizer : public Primitive {
@@ -43,9 +47,6 @@ int detail::RetainGraph::tracing_counter{0};
 
 array eval_impl(std::vector<array> outputs, bool async) {
   std::deque<array> tape;
 
-  // stream events to use for synchronization
-  std::unordered_map<int, Event> events;
-
   // Make an effort to choose a good output stream
   Stream stream = default_stream(default_device());
   for (auto& o : outputs) {
     }
   }
 
-  std::unordered_set<std::uintptr_t> needs_signal;
+  // Map of array id that needs a fence and the stream it's computed on
+  std::unordered_map<std::uintptr_t, int> needs_fence;
 
   auto synchronizer = array(
       {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
-  needs_signal.insert(synchronizer.id());
 
-  // Make an event for the synchronizer stream
+  // Stream fences for inter-stream synchronization
+  std::unordered_map<int, Fence> fences;
+
+  // Stream events for synchronization after eval
+  std::unordered_map<int, Event> events;
   events.emplace(stream.index, Event{stream});
 
   {
@@ -78,11 +83,6 @@
         // Add an input, and continue
         auto& in = a.inputs()[idx++];
 
-        // Ignore arrays already scheduled
-        if (in.status() == array::Status::scheduled) {
-          continue;
-        }
-
         if (in.status() == array::Status::unscheduled) {
           if (async && in.is_tracer()) {
             throw std::invalid_argument(
@@ -104,7 +104,7 @@
                 "https://github.com/ml-explore/mlx/issues.");
           }
           if (a.primitive().stream() != in.primitive().stream()) {
-            needs_signal.insert(in.id());
+            needs_fence.emplace(in.id(), in.primitive().stream().index);
           }
         }
@@ -185,62 +185,81 @@
 
    auto stream = arr.primitive().stream();
 
-    // Lookup corresponding event and increment counter
+    // Lookup corresponding event
    auto e = events.find(stream.index);
    if (e == events.end()) {
      e = events.emplace(stream.index, Event{stream}).first;
    }
-    e->second.set_value(e->second.value() + 1);
+    e->second.set_value(1);
    arr.attach_event(e->second);
    for (auto& s : arr.siblings()) {
      s.attach_event(e->second);
    }
 
-    // Set the status of the array and siblings.
-    arr.set_status(array::Status::scheduled);
-    for (auto& s : arr.siblings()) {
-      s.set_status(array::Status::scheduled);
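+    // Synchronize with inputs produced on other streams: fences order work
+    // within this eval, while events cover arrays coming from earlier
+    // (possibly async) evals.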
+    for (auto& in : arr.inputs()) {
+      if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
+        // Use a fence to wait within a single eval:
+        // get the input array's stream fence and wait on the
+        // output array's stream
+        fences[it->second].wait(stream, in);
+      } else if (in.event().valid()) {
+        if (in.event().is_signaled()) {
+          in.detach_event();
+        } else if (in.event().stream() != stream) {
+          // Use an event to wait across async evals
+          in.event().wait(stream);
+        }
+      }
     }
 
-    std::vector<std::shared_future<void>> arr_deps;
-    bool signal = needs_signal.find(arr.id()) != needs_signal.end();
-
     if (arr.primitive().device() == Device::gpu) {
       if (!metal::is_available()) {
         throw std::runtime_error("Metal GPU is not available.");
       }
-      scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
+      metal::eval(arr);
     } else {
-      auto task = [arr = std::move(arr), stream, signal]() mutable {
-        for (auto& input : arr.inputs()) {
-          if (input.event().valid() &&
-              input.event().stream() != arr.primitive().stream()) {
-            input.event().wait();
-          }
-        }
-        scheduler::notify_new_task(stream);
-        auto outputs = arr.outputs();
-        try {
-          arr.primitive().eval_cpu(arr.inputs(), outputs);
-        } catch (const std::exception& error) {
-          abort_with_exception(error);
-        }
-        if (!arr.is_tracer()) {
-          arr.detach();
-        }
-        for (auto& out : outputs) {
-          out.set_status(array::Status::available);
-        }
+      cpu::eval(arr);
+    }
 
-        if (signal) {
-          arr.event().signal();
+    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) {
+      // Commit any open streams
+      for (auto& [_, e] : events) {
+        if (e.stream().device == Device::gpu) {
+          metal::finalize(e.stream());
         }
+      }
+      scheduler::wait_for_one();
+    }
 
-        scheduler::notify_task_completion(stream);
-      };
-      scheduler::enqueue(stream, std::move(task));
+    auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
+      if (needs_fence.find(a.id()) != needs_fence.end()) {
+        auto it = fences.find(stream.index);
+        if (it == fences.end()) {
+          it = fences.emplace(stream.index, Fence{stream}).first;
+        }
+        it->second.update(stream, a);
+      }
+    };
+
+    arr.set_status(array::Status::evaluated);
+    // TODO: Maybe always want the fence coherent kernel in the same cbuf
+    // as the other kernels?
+    maybe_update_fence(arr);
+    for (auto& sib : arr.siblings()) {
+      sib.set_status(array::Status::evaluated);
+      maybe_update_fence(sib);
+    }
+    if (!arr.is_tracer()) {
+      arr.detach();
    }
  }
+
+  // Signal the event in its stream
+  for (auto& [_, e] : events) {
+    auto s = e.stream();
+    e.signal(s);
+    if (s.device == Device::gpu) {
+      metal::finalize(s);
+    }
+  }
+
  return synchronizer;
}
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 11a9488b3..9168b34c8 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -20,6 +20,16 @@ Stream to_stream(StreamOrDevice s) {
   }
 }
 
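+// Like to_stream(s), but fall back to the default stream of `default_` when
+// `s` holds neither a Stream nor a Device.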
"True" : "False"); diff --git a/mlx/utils.h b/mlx/utils.h index 11ac37be8..0b5ce54a1 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -14,6 +14,7 @@ namespace mlx::core { using StreamOrDevice = std::variant; Stream to_stream(StreamOrDevice s); +Stream to_stream(StreamOrDevice s, Device default_); struct StreamContext { public: diff --git a/python/mlx/extension.py b/python/mlx/extension.py index afece6d83..ecf3c52e6 100644 --- a/python/mlx/extension.py +++ b/python/mlx/extension.py @@ -6,7 +6,7 @@ import subprocess import sys from pathlib import Path -from setuptools import Extension, find_namespace_packages, setup +from setuptools import Extension from setuptools.command.build_ext import build_ext import mlx diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 2a3c1e660..b7173deb7 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -62,7 +62,7 @@ def tree_map_with_path( tree: Any, *rest: Any, is_leaf: Optional[Callable] = None, - path: Any = None, + path: Optional[Any] = None, ) -> Any: """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and returns a new collection with the results. @@ -74,8 +74,9 @@ def tree_map_with_path( fn (callable): The function that processes the leaves of the tree. tree (Any): The main Python tree that will be iterated upon. rest (tuple[Any]): Extra trees to be iterated together with ``tree``. - is_leaf (callable, optional): An optional callable that returns ``True`` + is_leaf (Optional[Callable]): An optional callable that returns ``True`` if the passed object is considered a leaf or ``False`` otherwise. + path (Optional[Any]): Prefix will be added to the result. Returns: A Python tree with the new values returned by ``fn``. diff --git a/python/src/random.cpp b/python/src/random.cpp index 49f3220cb..4b82f5479 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -253,7 +253,7 @@ void init_random(nb::module_& parent_module) { The values are sampled with equal probability from the integers in half-open interval ``[low, high)``. The lower and upper bound can be - scalars or arrays and must be roadcastable to ``shape``. + scalars or arrays and must be broadcastable to ``shape``. Args: low (scalar or array): Lower bound of the interval. diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp index d30d01cfc..ad6184f92 100644 --- a/tests/scheduler_tests.cpp +++ b/tests/scheduler_tests.cpp @@ -37,8 +37,8 @@ TEST_CASE("test stream management") { } TEST_CASE("test asynchronous launch") { - auto s1 = default_stream(default_device()); - auto s2 = new_stream(default_device()); + auto s1 = default_stream(Device::cpu); + auto s2 = new_stream(Device::cpu); // Make sure streams execute asynchronously int x = 1; @@ -67,8 +67,8 @@ TEST_CASE("test asynchronous launch") { } TEST_CASE("test stream placement") { - auto s1 = default_stream(default_device()); - auto s2 = new_stream(default_device()); + auto s1 = default_stream(Device::cpu); + auto s2 = new_stream(Device::cpu); { // Wait on stream 1