diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..327549e91 --- /dev/null +++ b/.gitignore @@ -0,0 +1,75 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Metal libraries +*.metallib + +# Distribution / packaging +python/mlx/share +python/mlx/include +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# vim +*.swp + +# Ignore build dir +build/ + +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# VSCode +.vscode/ +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..b8babecad --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v14.0.6 + hooks: + - id: clang-format +- repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..ba5d3a6e2 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,197 @@ +cmake_minimum_required(VERSION 3.24) + +project(mlx LANGUAGES CXX) + +# ----------------------------- Setup ----------------------------- +set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_INSTALL_MESSAGE NEVER) + +# ----------------------------- Configuration ----------------------------- +option(MLX_BUILD_TESTS "Build tests for mlx" ON) +option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON) +option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) +option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) +option(MLX_BUILD_METAL "Build metal backend" ON) +option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) + +if(NOT MLX_VERSION) + set(MLX_VERSION 0.0.1) +endif() + +# ----------------------------- Lib ----------------------------- + +include(FetchContent) +# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24: +cmake_policy(SET CMP0135 NEW) + +add_library(mlx) + +if (MLX_BUILD_METAL) + find_library(METAL_LIB Metal) + find_library(FOUNDATION_LIB Foundation) + find_library(QUARTZ_LIB QuartzCore) +endif() + +if (MLX_BUILD_METAL AND NOT METAL_LIB) + message(STATUS "Metal not found. 
Unable to build GPU") +elseif (MLX_BUILD_METAL) + message(STATUS "Building METAL sources") + add_compile_definitions(_METAL_) + + execute_process(COMMAND zsh "-c" "/usr/bin/sw_vers | cut -f2- -d: | sed -n 2p | grep -Eo '[0-9]+.[0-9]+'" + OUTPUT_VARIABLE MACOS_VERSION) + + message(STATUS "Detected macOS version ${MACOS_VERSION}") + if (${MACOS_VERSION} GREATER_EQUAL 14.0) + set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip) + elseif (${MACOS_VERSION} GREATER_EQUAL 13.3) + set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip) + else() + set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip) + endif() + + FetchContent_Declare( + metal_cpp + URL ${METAL_CPP_URL} + ) + + FetchContent_MakeAvailable(metal_cpp) + target_include_directories( + mlx PUBLIC + $ + $ + ) + target_link_libraries( + mlx + ${METAL_LIB} + ${FOUNDATION_LIB} + ${QUARTZ_LIB}) +endif() + +find_library(ACCELERATE_LIBRARY Accelerate) +if (ACCELERATE_LIBRARY) + message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") + set(MLX_BUILD_ACCELERATE ON) + target_link_libraries(mlx ${ACCELERATE_LIBRARY}) + add_compile_definitions(ACCELERATE_NEW_LAPACK) +else() + message(STATUS "Accelerate not found, using default backend.") + set(MLX_BUILD_ACCELERATE OFF) + #set(BLA_VENDOR Generic) + find_package(BLAS REQUIRED) + if (NOT BLAS_FOUND) + message(FATAL_ERROR "Must have BLAS installed") + endif() + # TODO find a cleaner way to do this + find_path(BLAS_INCLUDE_DIRS cblas.h + /usr/include + /usr/local/include + $ENV{BLAS_HOME}/include) + message(STATUS ${BLAS_LIBRARIES}) + message(STATUS ${BLAS_INCLUDE_DIRS}) + target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) + target_link_libraries(mlx ${BLAS_LIBRARIES}) +endif() + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) + +target_include_directories( + mlx + PUBLIC + $ + $ +) + +if (MLX_BUILD_PYTHON_BINDINGS) + message(STATUS "Building Python bindings.") + find_package(Python COMPONENTS Interpreter Development) + find_package(pybind11 CONFIG REQUIRED) + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) +endif() + +if (MLX_BUILD_TESTS) + include(CTest) + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests) +endif() + +if (MLX_BUILD_EXAMPLES) + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp) +endif() + +if (MLX_BUILD_BENCHMARKS) + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) +endif() + +# ----------------------------- Installation ----------------------------- +include(GNUInstallDirs) + +# Install library +install( + TARGETS mlx + EXPORT MLXTargets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) + + +# Install headers +install( + DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + COMPONENT headers + FILES_MATCHING PATTERN "*.h" +) + +# Install metal dependencies +if (MLX_BUILD_METAL) + + # Install metal cpp + install( + DIRECTORY ${metal_cpp_SOURCE_DIR}/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp + COMPONENT metal_cpp_source + ) + +endif() + +# Install cmake config +set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake) +set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake) +set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX) + +install( + EXPORT MLXTargets + FILE MLXTargets.cmake + DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} +) + 
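# A minimal sketch of how a downstream project might consume the package
# config installed above; the imported target name `mlx` follows from the
# unnamespaced `install(EXPORT MLXTargets ...)`, and the project/source names
# below are placeholders.
#
#   cmake_minimum_required(VERSION 3.24)
#   project(mlx_consumer LANGUAGES CXX)
#   find_package(MLX CONFIG REQUIRED)           # loads MLXConfig.cmake from share/cmake/MLX
#   add_executable(example example.cpp)
#   target_link_libraries(example PRIVATE mlx)  # links the installed mlx target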
+include(CMakePackageConfigHelpers) + +write_basic_package_version_file( + ${MLX_CMAKE_BUILD_VERSION_CONFIG} + COMPATIBILITY SameMajorVersion + VERSION ${MLX_VERSION} +) + +configure_package_config_file( + ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in + ${MLX_CMAKE_BUILD_CONFIG} + INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} + NO_CHECK_REQUIRED_COMPONENTS_MACRO + PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR +) + +install( + FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG} + DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} +) + +install( + DIRECTORY ${CMAKE_MODULE_PATH}/ + DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} +) \ No newline at end of file diff --git a/benchmarks/cpp/time_utils.h b/benchmarks/cpp/time_utils.h new file mode 100644 index 000000000..9a4a0778c --- /dev/null +++ b/benchmarks/cpp/time_utils.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +#include "mlx/mlx.h" + +#define milliseconds(x) \ + (std::chrono::duration_cast(x).count() / 1e6) +#define time_now() std::chrono::high_resolution_clock::now() + +#define TIME(FUNC, ...) \ + std::cout << "Timing " << #FUNC << " ... " << std::flush \ + << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ + << std::endl; + +#define TIMEM(MSG, FUNC, ...) \ + std::cout << "Timing " \ + << "(" << MSG << ") " << #FUNC << " ... " << std::flush \ + << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ + << std::endl; + +template +double time_fn(F fn, Args... args) { + // warmup + for (int i = 0; i < 5; ++i) { + eval(fn(std::forward(args)...)); + } + + int num_iters = 100; + auto start = time_now(); + for (int i = 0; i < num_iters; i++) { + eval(fn(std::forward(args)...)); + } + auto end = time_now(); + return milliseconds(end - start) / static_cast(num_iters); +} diff --git a/benchmarks/python/batch_matmul_bench.py b/benchmarks/python/batch_matmul_bench.py new file mode 100644 index 000000000..11f0734eb --- /dev/null +++ b/benchmarks/python/batch_matmul_bench.py @@ -0,0 +1,60 @@ +import argparse +import mlx.core as mx + +from time_utils import time_fn + +B = 8 +T = 1024 +D = 512 + + +def time_batch_matmul(): + mx.random.seed(3) + a = mx.random.uniform(shape=(B, T, D)) + b = mx.random.uniform(shape=(D, D)) + c = mx.random.uniform(shape=(B, T, D)) + mx.eval(a, b, c) + + time_fn(mx.matmul, a, b) + + def batch_vjp_first(): + return mx.vjp(mx.matmul, [a, b], [c])[1][0] + + time_fn(batch_vjp_first) + + def batch_vjp_second(): + return mx.vjp(mx.matmul, [a, b], [c])[1][1] + + time_fn(batch_vjp_second) + + +def time_unbatch_matmul(key): + mx.random.seed(3) + a = mx.random.uniform(shape=(B * T, D)) + b = mx.random.uniform(shape=(D, D)) + c = mx.random.uniform(shape=(B * T, D)) + mx.eval(a, b, c) + time_fn(mx.matmul, a, b) + + def unbatch_vjp_first(): + return mx.matmul(c, mx.transpose(b)) + + time_fn(unbatch_vjp_first) + + def unbatch_vjp_second(): + return mx.matmul(mx.transpose(a), c) + + time_fn(unbatch_vjp_second) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MLX benchmarks.") + parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") + args = parser.parse_args() + if args.gpu: + mx.set_default_device(mx.gpu) + else: + mx.set_default_device(mx.cpu) + + time_batch_matmul() + time_unbatch_matmul() diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py new file mode 100644 index 000000000..e067c8098 --- /dev/null +++ b/benchmarks/python/time_utils.py @@ -0,0 +1,20 @@ +import time + 
+import mlx.core as mx + + +def time_fn(fn, *args, **kwargs): + print(f"Timing {fn.__name__} ...", end=" ") + + # warmup + for _ in range(5): + mx.eval(fn(*args, **kwargs)) + + num_iters = 100 + tic = time.perf_counter() + for _ in range(num_iters): + x = mx.eval(fn(*args, **kwargs)) + toc = time.perf_counter() + + msec = 1e3 * (toc - tic) / num_iters + print(f"{msec:.5f} msec") diff --git a/docs/.clang-format b/docs/.clang-format new file mode 100644 index 000000000..47a38a93f --- /dev/null +++ b/docs/.clang-format @@ -0,0 +1,2 @@ +DisableFormat: true +SortIncludes: Never diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 000000000..e5888bc2f --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,18 @@ +# Minimal makefile for Sphinx documentation + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = src +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/src/examples/mlp.rst b/docs/src/examples/mlp.rst new file mode 100644 index 000000000..5763eeba0 --- /dev/null +++ b/docs/src/examples/mlp.rst @@ -0,0 +1,131 @@ +.. _mlp: + +Multi-Layer Perceptron +---------------------- + +In this example we'll learn to use ``mlx.nn`` by implementing a simple +multi-layer perceptron to classify MNIST. + +As a first step import the MLX packages we need: + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + import mlx.optimizers as optim + + import numpy as np + + +The model is defined as the ``MLP`` class which inherits from +:class:`mlx.nn.Module`. We follow the standard idiom to make a new module: + +1. Define an ``__init__`` where the parameters and/or submodules are setup. See + the :ref:`Module class docs` for more information on how + :class:`mlx.nn.Module` registers parameters. +2. Define a ``__call__`` where the computation is implemented. + +.. code-block:: python + + class MLP(nn.Module): + def __init__( + self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int + ): + super().__init__() + layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] + self.layers = [ + nn.Linear(idim, odim) + for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + + def __call__(self, x): + for l in self.layers[:-1]: + x = mx.maximum(l(x), 0.0) + return self.layers[-1](x) + + +We define the loss function which takes the mean of the per-example cross +entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some +commonly used loss functions. + +.. code-block:: python + + def loss_fn(model, X, y): + return mx.mean(nn.losses.cross_entropy(model(X), y)) + +We also need a function to compute the accuracy of the model on the validation +set: + +.. code-block:: python + + def eval_fn(model, X, y): + return mx.mean(mx.argmax(model(X), axis=1) == y) + +Next, setup the problem parameters and load the data: + +.. 
code-block:: python + + num_layers = 2 + hidden_dim = 32 + num_classes = 10 + batch_size = 256 + num_epochs = 10 + learning_rate = 1e-1 + + # Load the data + import mnist + train_images, train_labels, test_images, test_labels = map( + mx.array, mnist.mnist() + ) + +Since we're using SGD, we need an iterator which shuffles and constructs +minibatches of examples in the training set: + +.. code-block:: python + + def batch_iterate(batch_size, X, y): + perm = mx.array(np.random.permutation(y.size)) + for s in range(0, y.size, batch_size): + ids = perm[s : s + batch_size] + yield X[ids], y[ids] + + +Finally, we put it all together by instantiating the model, the +:class:`mlx.optimizers.SGD` optimizer, and running the training loop: + +.. code-block:: python + + # Load the model + model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) + mx.eval(model.parameters()) + + # Get a function which gives the loss and gradient of the + # loss with respect to the model's trainable parameters + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + + # Instantiate the optimizer + optimizer = optim.SGD(learning_rate=learning_rate) + + for e in range(num_epochs): + for X, y in batch_iterate(batch_size, train_images, train_labels): + loss, grads = loss_and_grad_fn(model, X, y) + + # Update the optimizer state and model parameters + # in a single call + optimizer.update(model, grads) + + # Force a graph evaluation + mx.eval(model.parameters(), optimizer.state) + + accuracy = eval_fn(model, test_images, test_labels) + print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}") + + +.. note:: + The :func:`mlx.nn.value_and_grad` function is a convenience function to get + the gradient of a loss with respect to the trainable parameters of a model. + This should not be confused with :func:`mlx.core.value_and_grad`. + +The model should train to a decent accuracy (about 95%) after just a few passes +over the training set. The `full example `_ +is available in the MLX GitHub repo. diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst new file mode 100644 index 000000000..96ddd32b3 --- /dev/null +++ b/docs/src/python/array.rst @@ -0,0 +1,45 @@ +.. _array: + +Array +===== + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + array + array.astype + array.item + array.tolist + array.dtype + array.ndim + array.shape + array.size + Dtype + array.abs + array.all + array.any + array.argmax + array.argmin + array.cos + array.dtype + array.exp + array.log + array.log1p + array.logsumexp + array.max + array.mean + array.min + array.prod + array.reciprocal + array.reshape + array.rsqrt + array.sin + array.split + array.sqrt + array.square + array.sum + array.transpose + array.T + array.var diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst new file mode 100644 index 000000000..450588536 --- /dev/null +++ b/docs/src/python/ops.rst @@ -0,0 +1,94 @@ +.. _ops: + +Operations +========== + +.. currentmodule:: mlx.core + +.. 
autosummary:: + :toctree: _autosummary + + abs + add + all + allclose + any + arange + arccos + arccosh + arcsin + arcsinh + arctan + arctanh + argmax + argmin + argpartition + argsort + array_equal + broadcast_to + concatenate + convolve + conv1d + conv2d + cos + cosh + divide + equal + erf + erfinv + exp + expand_dims + full + greater + greater_equal + less + less_equal + load + log + log2 + log10 + log1p + logaddexp + logical_not + logsumexp + matmul + max + maximum + mean + min + minimum + multiply + negative + ones + ones_like + partition + pad + prod + reciprocal + reshape + rsqrt + save + savez + savez_compressed + sigmoid + sign + sin + sinh + softmax + sort + split + sqrt + square + squeeze + stop_gradient + subtract + sum + take + take_along_axis + tan + tanh + transpose + var + where + zeros + zeros_like diff --git a/examples/cpp/timer.h b/examples/cpp/timer.h new file mode 100644 index 000000000..5a270f00a --- /dev/null +++ b/examples/cpp/timer.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +namespace timer { + +using namespace std::chrono; + +template +inline double seconds(duration x) { + return duration_cast(x).count() / 1e9; +} + +inline auto time() { + return high_resolution_clock::now(); +} + +} // namespace timer diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp new file mode 100644 index 000000000..1e6a995bf --- /dev/null +++ b/examples/extensions/axpby/axpby.cpp @@ -0,0 +1,359 @@ +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include "axpby/axpby.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#endif + +#ifdef _METAL_ +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#endif + +namespace mlx::core { + +/////////////////////////////////////////////////////////////////////////////// +// Operation Implementation +/////////////////////////////////////////////////////////////////////////////// + +/** + * Scale and sum two vectors elementwise + * z = alpha * x + beta * y + * + * Follow numpy style broadcasting between x and y + * Inputs are upcasted to floats if needed + **/ +array axpby( + const array& x, // Input array x + const array& y, // Input array y + const float alpha, // Scaling factor for x + const float beta, // Scaling factor for y + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + // Promote dtypes between x and y as needed + auto promoted_dtype = promote_types(x.dtype(), y.dtype()); + + // Upcast to float32 for non-floating point inputs x and y + auto out_dtype = is_floating_point(promoted_dtype) + ? 
promoted_dtype + : promote_types(promoted_dtype, float32); + + // Cast x and y up to the determined dtype (on the same stream s) + auto x_casted = astype(x, out_dtype, s); + auto y_casted = astype(y, out_dtype, s); + + // Broadcast the shapes of x and y (on the same stream s) + auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s); + auto out_shape = broadcasted_inputs[0].shape(); + + // Construct the array as the output of the Axpby primitive + // with the broadcasted and upcasted arrays as inputs + return array( + /* const std::vector& shape = */ out_shape, + /* Dtype dtype = */ out_dtype, + /* std::unique_ptr primitive = */ + std::make_unique(to_stream(s), alpha, beta), + /* const std::vector& inputs = */ broadcasted_inputs); +} + +/////////////////////////////////////////////////////////////////////////////// +// Primitive Common Backend Implementation +/////////////////////////////////////////////////////////////////////////////// + +template +void axpby_impl( + const array& x, + const array& y, + array& out, + float alpha_, + float beta_) { + // We only allocate memory when we are ready to fill the output + // malloc_or_wait synchronously allocates available memory + // There may be a wait executed here if the allocation is requested + // under memory-pressured conditions + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + // Collect input and output data pointers + const T* x_ptr = x.data(); + const T* y_ptr = y.data(); + T* out_ptr = out.data(); + + // Cast alpha and beta to the relevant types + T alpha = static_cast(alpha_); + T beta = static_cast(beta_); + + // Do the elementwise operation for each output + for (size_t out_idx = 0; out_idx < out.size(); out_idx++) { + // Map linear indices to offsets in x and y + auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides()); + auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides()); + + // We allocate the output to be contiguous and regularly strided + // (defaults to row major) and hence it doesn't need additonal mapping + out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; + } +} + +/** Fall back implementation for evaluation on CPU */ +void Axpby::eval(const std::vector& inputs, array& out) { + // Check the inputs (registered in the op while contructing the out array) + assert(inputs.size() == 2); + auto& x = inputs[0]; + auto& y = inputs[1]; + + // Dispatch to the correct dtype + if (out.dtype() == float32) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == float16) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == bfloat16) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == complex64) { + return axpby_impl(x, y, out, alpha_, beta_); + } else { + throw std::runtime_error( + "Axpby is only supported for floating point types."); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Primitive Accelerate Backend Implementation +/////////////////////////////////////////////////////////////////////////////// + +#ifdef ACCELERATE_NEW_LAPACK + +template +void axpby_impl_accelerate( + const array& x, + const array& y, + array& out, + float alpha_, + float beta_) { + // Accelerate library provides catlas_saxpby which does + // Y = (alpha * X) + (beta * Y) in place + // To use it, we first copy the data in y over to the output array + + // This specialization requires both x and y be contiguous in the same mode + // i.e: corresponding linear indices in both point to 
corresponding elements + // The data in the output array is allocated to match the strides in y + // such that x, y, and out are contiguous in the same mode and + // no transposition is needed + out.set_data( + allocator::malloc_or_wait(y.data_size() * out.itemsize()), + y.data_size(), + y.strides(), + y.flags()); + + // We then copy over the elements using the contiguous vector specialization + copy_inplace(y, out, CopyType::Vector); + + // Get x and y pointers for catlas_saxpby + const T* x_ptr = x.data(); + T* y_ptr = out.data(); + + T alpha = static_cast(alpha_); + T beta = static_cast(beta_); + + // Call the inplace accelerate operator + catlas_saxpby( + /* N = */ out.size(), + /* ALPHA = */ alpha, + /* X = */ x_ptr, + /* INCX = */ 1, + /* BETA = */ beta, + /* Y = */ y_ptr, + /* INCY = */ 1); +} + +/** Evaluate primitive on CPU using accelerate specializations */ +void Axpby::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& x = inputs[0]; + auto& y = inputs[1]; + + // Accelerate specialization for contiguous single precision float arrays + if (out.dtype() == float32 && + ((x.flags().row_contiguous && y.flags().row_contiguous) || + (x.flags().col_contiguous && y.flags().col_contiguous))) { + axpby_impl_accelerate(x, y, out, alpha_, beta_); + return; + } + + // Fall back to common backend if specializations are not available + eval(inputs, out); +} + +#else // Accelerate not avaliable + +/** Evaluate primitive on CPU falling back to common backend */ +void Axpby::eval_cpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Primitive Metal Backend Implementation +/////////////////////////////////////////////////////////////////////////////// + +#ifdef _METAL_ + +/** Evaluate primitive on GPU */ +void Axpby::eval_gpu(const std::vector& inputs, array& out) { + // Prepare inputs + assert(inputs.size() == 2); + auto& x = inputs[0]; + auto& y = inputs[1]; + + // Each primitive carries the stream it should execute on + // and each stream carries its device identifiers + auto& s = stream(); + // We get the needed metal device using the stream + auto& d = metal::device(s.device); + + // Prepare to specialize based on contiguity + bool contiguous_kernel = + (x.flags().row_contiguous && y.flags().row_contiguous) || + (x.flags().col_contiguous && y.flags().col_contiguous); + + // Allocate output memory with strides based on specialization + if (contiguous_kernel) { + out.set_data( + allocator::malloc_or_wait(x.data_size() * out.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + + // Resolve name of kernel (corresponds to axpby.metal) + std::ostringstream kname; + kname << "axpby_"; + kname << (contiguous_kernel ? 
"contiguous_" : "general_"); + kname << type_to_name(out); + + // Make sure the metal library is available and look for it + // in the same folder as this executable if needed + d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + + // Make a kernel from this metal library + auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + + // Prepare to encode kernel + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + // Kernel parameters are registered with buffer indices corresponding to + // those in the kernel decelaration at axpby.metal + int ndim = out.ndim(); + size_t nelem = out.size(); + + // Encode input arrays to kernel + set_array_buffer(compute_encoder, x, 0); + set_array_buffer(compute_encoder, y, 1); + + // Encode output arrays to kernel + set_array_buffer(compute_encoder, out, 2); + + // Encode alpha and beta + compute_encoder->setBytes(&alpha_, sizeof(float), 3); + compute_encoder->setBytes(&beta_, sizeof(float), 4); + + // Encode shape, strides and ndim if needed + if (!contiguous_kernel) { + compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); + compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); + compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); + compute_encoder->setBytes(&ndim, sizeof(int), 8); + } + + // We launch 1 thread for each input and make sure that the number of + // threads in any given threadgroup is not higher than the max allowed + size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup()); + + // Fix the 3D size of each threadgroup (in terms of threads) + MTL::Size group_dims = MTL::Size(tgp_size, 1, 1); + + // Fix the 3D size of the launch grid (in terms of threads) + MTL::Size grid_dims = MTL::Size(nelem, 1, 1); + + // Launch the grid with the given number of threads divded among + // the given threadgroups + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +#else // Metal is not available + +/** Fail evaluation on GPU */ +void Axpby::eval_gpu(const std::vector& inputs, array& out) { + throw std::runtime_error("Axpby has no GPU implementation."); +} + +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Primitive Transforms +/////////////////////////////////////////////////////////////////////////////// + +/** The Jacobian-vector product. */ +array Axpby::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + // Forward mode diff that pushes along the tangents + // The jvp transform on the the primitive can built with ops + // that are scheduled on the same stream as the primtive + + // If argnums = {0}, we only push along x in which case the + // jvp is just the tangent scaled by alpha + // Similarly, if argnums = {1}, the jvp is just the tangent + // scaled by beta + if (argnums.size() > 1) { + auto scale = argnums[0] == 0 ? alpha_ : beta_; + auto scale_arr = array(scale, tangents[0].dtype()); + return multiply(scale_arr, tangents[0], stream()); + } + // If, argnums = {0, 1}, we take contributions from both + // which gives us jvp = tangent_x * alpha + tangent_y * beta + else { + return axpby(tangents[0], tangents[1], alpha_, beta_, stream()); + } +} + +/** The vector-Jacobian product. */ +std::vector Axpby::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + // Reverse mode diff + std::vector vjps; + for (auto arg : argnums) { + auto scale = arg == 0 ? 
alpha_ : beta_; + auto scale_arr = array(scale, cotan.dtype()); + vjps.push_back(multiply(scale_arr, cotan, stream())); + } + return vjps; +} + +/** Vectorize primitve along given axis */ +std::pair Axpby::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::runtime_error("Axpby has no vmap implementation."); +} + +/** Equivalence check **/ +bool Axpby::is_equivalent(const Primitive& other) const { + const Axpby& r_other = static_cast(other); + return alpha_ == r_other.alpha_ && beta_ == r_other.beta_; +} + +} // namespace mlx::core \ No newline at end of file diff --git a/examples/extensions/bindings.cpp b/examples/extensions/bindings.cpp new file mode 100644 index 000000000..d2eeca4ac --- /dev/null +++ b/examples/extensions/bindings.cpp @@ -0,0 +1,39 @@ +#include +#include + +#include "axpby/axpby.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +PYBIND11_MODULE(mlx_sample_extensions, m) { + m.doc() = "Sample C++ and metal extensions for MLX"; + + m.def( + "axpby", + &axpby, + "x"_a, + "y"_a, + py::pos_only(), + "alpha"_a, + "beta"_a, + py::kw_only(), + "stream"_a = py::none(), + R"pbdoc( + Scale and sum two vectors elementwise + ``z = alpha * x + beta * y`` + + Follows numpy style broadcasting between ``x`` and ``y`` + Inputs are upcasted to floats if needed + + Args: + x (array): Input array. + y (array): Input array. + alpha (float): Scaling factor for ``x``. + beta (float): Scaling factor for ``y``. + + Returns: + array: ``alpha * x + beta * y`` + )pbdoc"); +} \ No newline at end of file diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt new file mode 100644 index 000000000..bd28537f1 --- /dev/null +++ b/mlx/CMakeLists.txt @@ -0,0 +1,36 @@ +target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h +) + +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) + +if (MLX_BUILD_ACCELERATE) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) +else() + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp + ) +endif() + +if (MLX_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) +endif() diff --git a/mlx/array.cpp b/mlx/array.cpp new file mode 100644 index 000000000..20ea2e407 --- /dev/null +++ b/mlx/array.cpp @@ -0,0 +1,143 @@ +#include + +#include "mlx/array.h" +#include "mlx/ops.h" +#include "mlx/primitives.h" +#include "mlx/transforms.h" + +namespace mlx::core { + +namespace { + +std::pair> cum_prod(const std::vector& shape) { + std::vector strides(shape.size()); + size_t cum_prod = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + strides[i] = cum_prod; + cum_prod *= shape[i]; + } + return {cum_prod, strides}; +} + +} // namespace + +array::array(const std::complex& val, Dtype dtype /* = complex64 */) + : array_desc_(std::make_shared(std::vector{}, dtype)) { + auto cval = static_cast(val); + 
init(&cval); +} + +array::array( + const std::vector& shape, + Dtype dtype, + std::unique_ptr primitive, + const std::vector& inputs) + : array_desc_(std::make_shared( + shape, + dtype, + std::move(primitive), + inputs)) {} + +array::array(std::initializer_list data) + : array_desc_(std::make_shared( + std::vector{static_cast(data.size())}, + float32)) { + init(data.begin()); +} + +/* Build an array from a shared buffer */ +array::array( + allocator::Buffer data, + const std::vector& shape, + Dtype dtype, + deleter_t deleter) + : array_desc_(std::make_shared(shape, dtype)) { + set_data(data, deleter); +} + +void array::detach() { + array_desc_->inputs.clear(); + array_desc_->primitive = nullptr; +} + +void array::eval(bool retain_graph /* = false */) { + mlx::core::eval({*this}, retain_graph); +} + +void array::set_data(allocator::Buffer buffer, deleter_t d) { + array_desc_->data = std::make_shared(buffer, d); + array_desc_->data_ptr = buffer.raw_ptr(); + array_desc_->data_size = size(); + array_desc_->flags.contiguous = true; + array_desc_->flags.row_contiguous = true; + auto max_dim = std::max_element(shape().begin(), shape().end()); + array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim; +} + +void array::set_data( + allocator::Buffer buffer, + size_t data_size, + std::vector strides, + Flags flags, + deleter_t d) { + array_desc_->data = std::make_shared(buffer, d); + array_desc_->data_ptr = buffer.raw_ptr(); + array_desc_->data_size = data_size; + array_desc_->strides = std::move(strides); + array_desc_->flags = flags; +} + +void array::copy_shared_buffer( + const array& other, + const std::vector& strides, + Flags flags, + size_t data_size, + size_t offset /* = 0 */) { + array_desc_->data = other.array_desc_->data; + array_desc_->strides = strides; + array_desc_->flags = flags; + array_desc_->data_size = data_size; + auto char_offset = sizeof(char) * itemsize() * offset; + array_desc_->data_ptr = static_cast( + static_cast(other.array_desc_->data_ptr) + char_offset); +} + +void array::copy_shared_buffer(const array& other) { + copy_shared_buffer(other, other.strides(), other.flags(), other.data_size()); +} + +array::ArrayDesc::ArrayDesc(const std::vector& shape, Dtype dtype) + : shape(shape), dtype(dtype) { + std::tie(size, strides) = cum_prod(shape); +} + +array::ArrayDesc::ArrayDesc( + const std::vector& shape, + Dtype dtype, + std::unique_ptr primitive, + const std::vector& inputs) + : shape(shape), + dtype(dtype), + primitive(std::move(primitive)), + inputs(inputs) { + std::tie(size, strides) = cum_prod(shape); + for (auto& in : inputs) { + is_tracer |= in.is_tracer(); + } +} + +// Needed because the Primitive type used in array.h is incomplete and the +// compiler needs to see the call to the desctructor after the type is complete. 
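// For illustration, assuming the declarations in array.h: ArrayDesc holds a
// std::unique_ptr<Primitive> while array.h only forward-declares Primitive,
// so a destructor defaulted in the header would force unique_ptr to
// instantiate its deleter against an incomplete type. Defaulting it here,
// after primitives.h has been included above, avoids that.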
+array::ArrayDesc::~ArrayDesc() = default; + +array::ArrayIterator::reference array::ArrayIterator::operator*() const { + auto start = std::vector(arr.ndim(), 0); + auto end = arr.shape(); + auto shape = arr.shape(); + shape.erase(shape.begin()); + start[0] = idx; + end[0] = idx + 1; + return reshape(slice(arr, start, end), shape); +}; + +} // namespace mlx::core diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp new file mode 100644 index 000000000..31cdbd2da --- /dev/null +++ b/mlx/backend/accelerate/softmax.cpp @@ -0,0 +1,323 @@ +#include +#include + +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +/** + * Compute exp(x) in an optimizer friendly way as follows: + * + * First change the problem to computing 2**y where y = x / ln(2). + * + * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part + * `ipart` and y2 is fractional part. For the integer part we perform bit + * shifting and for the fractional part we use a polynomial approximation. + * + * The algorithm and constants of the polynomial taken from + * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them + * from Cephes math library. + * + * Note: The implementation below is a general fast exp. There could be faster + * implementations for numbers strictly < 0. + */ +inline simd_float16 simd_fast_exp(simd_float16 x) { + x *= 1.442695; // multiply with log_2(e) + simd_float16 ipart, fpart; + simd_int16 epart; + x = simd_clamp(x, -80, 80); + ipart = simd::floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = x * fpart + 1.339887440266574e-3f; + x = x * fpart + 9.618437357674640e-3f; + x = x * fpart + 5.550332471162809e-2f; + x = x * fpart + 2.402264791363012e-1f; + x = x * fpart + 6.931472028550421e-1f; + x = x * fpart + 1.000000000000000f; + + // generate 2**ipart in the floating point representation using integer + // bitshifting + epart = (simd_int(ipart) + 127) << 23; + + return (*(simd_float16*)&epart) * x; +} + +/** + * The ARM neon equivalent of the fast exp above. + */ +inline float16x8_t neon_fast_exp(float16x8_t x) { + x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e) + x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14 + x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14 + + float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5))); + float16x8_t fpart = vsubq_f16(x, ipart); + + x = vdupq_n_f16(1.535336188319500e-4f); + x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart); + x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart); + x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart); + x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart); + x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart); + x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart); + x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart); + + // generate 2**ipart in the floating point representation using integer + // bitshifting + int16x8_t epart = vcvtq_s16_f16(ipart); + epart = vaddq_s16(epart, vdupq_n_s16(15)); + epart = vshlq_n_s16(epart, 10); + + return vmulq_f16(vreinterpretq_f16_s16(epart), x); +} + +/** + * Implementation of folding maximum for ARM neon. This should possibly be + * refactored out of softmax.cpp at some point. 
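 * (The vpmax_f16 calls below take pairwise maxima: the first folds the low
 * and high halves of the eight-lane vector into four candidates, and two
 * further folds leave the overall maximum in lane 0, which vget_lane_f16
 * extracts.)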
+ */ +inline float16_t neon_reduce_max(float16x8_t x) { + float16x4_t y; + y = vpmax_f16(vget_low_f16(x), vget_high_f16(x)); + y = vpmax_f16(y, y); + y = vpmax_f16(y, y); + return vget_lane_f16(y, 0); +} + +/** + * Implementation of folding sum for ARM neon. This should possibly be + * refactored out of softmax.cpp at some point. + */ +inline float16_t neon_reduce_add(float16x8_t x) { + float16x4_t y; + float16x4_t zero = vdup_n_f16(0); + y = vpadd_f16(vget_low_f16(x), vget_high_f16(x)); + y = vpadd_f16(y, zero); + y = vpadd_f16(y, zero); + return vget_lane_f16(y, 0); +} + +template +struct AccelerateSimdOps { + VT init(T a) { + return a; + } + + VT load(const T* a) { + return *(VT*)a; + } + + void store(T* dst, VT x) { + *(VT*)dst = x; + } + + VT max(VT a, VT b) { + return simd_max(a, b); + }; + + VT exp(VT x) { + return simd_fast_exp(x); + } + + VT add(VT a, VT b) { + return a + b; + } + + VT sub(VT a, T b) { + return a - b; + } + + VT mul(VT a, VT b) { + return a * b; + } + + VT mul(VT a, T b) { + return a * b; + } + + T reduce_max(VT x) { + return simd_reduce_max(x); + } + + T reduce_add(VT x) { + return simd_reduce_add(x); + } +}; + +template +struct NeonFp16SimdOps { + VT init(T a) { + return vdupq_n_f16(a); + } + + VT load(const T* a) { + return vld1q_f16(a); + } + + void store(T* dst, VT x) { + vst1q_f16(dst, x); + } + + VT max(VT a, VT b) { + return vmaxq_f16(a, b); + }; + + VT exp(VT x) { + return neon_fast_exp(x); + } + + VT add(VT a, VT b) { + return vaddq_f16(a, b); + } + + VT sub(VT a, T b) { + return vsubq_f16(a, vdupq_n_f16(b)); + } + + VT mul(VT a, VT b) { + return vmulq_f16(a, b); + } + + VT mul(VT a, T b) { + return vmulq_f16(a, vdupq_n_f16(b)); + } + + T reduce_max(VT x) { + return neon_reduce_max(x); + } + + T reduce_add(VT x) { + return neon_reduce_add(x); + } +}; + +template +void softmax(const array& in, array& out) { + Ops ops; + + const T* in_ptr = in.data(); + T* out_ptr = out.data(); + int M = in.shape().back(); + int L = in.data_size() / M; + const T* current_in_ptr; + T* current_out_ptr; + + for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { + // Find the maximum + current_in_ptr = in_ptr; + VT vmaximum = ops.init(-std::numeric_limits::infinity()); + size_t s = M; + while (s >= N) { + vmaximum = ops.max(ops.load(current_in_ptr), vmaximum); + current_in_ptr += N; + s -= N; + } + T maximum = ops.reduce_max(vmaximum); + while (s-- > 0) { + maximum = std::max(maximum, *current_in_ptr); + current_in_ptr++; + } + + // Compute the normalizer and the exponentials + VT vnormalizer = ops.init(0.0); + current_out_ptr = out_ptr; + current_in_ptr = in_ptr; + s = M; + while (s >= N) { + VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum)); + ops.store(current_out_ptr, vexp); + *(VT*)current_out_ptr = vexp; + vnormalizer = ops.add(vnormalizer, vexp); + current_in_ptr += N; + current_out_ptr += N; + s -= N; + } + T normalizer = ops.reduce_add(vnormalizer); + while (s-- > 0) { + T _exp = std::exp(*current_in_ptr - maximum); + *current_out_ptr = _exp; + normalizer += _exp; + current_in_ptr++; + current_out_ptr++; + } + normalizer = 1 / normalizer; + + // Normalize + current_out_ptr = out_ptr; + s = M; + while (s >= N) { + ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer)); + current_out_ptr += N; + s -= N; + } + while (s-- > 0) { + *current_out_ptr *= normalizer; + current_out_ptr++; + } + } +} + +} // namespace + +void Softmax::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + // Make sure that the last dimension 
is contiguous + auto check_input = [](array x) { + if (x.strides()[x.ndim() - 1] == 1) { + return x; + } else { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy(x, x_copy, CopyType::General); + return x_copy; + } + }; + array in = check_input(std::move(inputs[0])); + out.set_data( + allocator::malloc_or_wait(in.data_size() * in.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + + switch (in.dtype()) { + case bool_: + case uint8: + case uint16: + case uint32: + case uint64: + case int8: + case int16: + case int32: + case int64: + throw std::invalid_argument( + "Softmax is defined only for floating point types"); + break; + case float32: + softmax, 16>( + in, out); + break; + case float16: + softmax< + float16_t, + float16x8_t, + NeonFp16SimdOps, + 8>(in, out); + break; + case bfloat16: + eval(inputs, out); + break; + case complex64: + eval(inputs, out); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp new file mode 100644 index 000000000..3e4166a7a --- /dev/null +++ b/mlx/backend/common/binary.cpp @@ -0,0 +1,216 @@ +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/backend/common/binary.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +template +void comparison_op(const array& a, const array& b, array& out, Op op) { + DefaultScalarVector opsv(op); + DefaultVectorScalar opvs(op); + DefaultVectorVector opvv(op); + binary_op(a, b, out, op, opsv, opvs, opvv); +} + +template +void comparison_op(const array& a, const array& b, array& out, Op op) { + switch (a.dtype()) { + case bool_: + comparison_op(a, b, out, op); + break; + case uint8: + comparison_op(a, b, out, op); + break; + case uint16: + comparison_op(a, b, out, op); + break; + case uint32: + comparison_op(a, b, out, op); + break; + case uint64: + comparison_op(a, b, out, op); + break; + case int8: + comparison_op(a, b, out, op); + break; + case int16: + comparison_op(a, b, out, op); + break; + case int32: + comparison_op(a, b, out, op); + break; + case int64: + comparison_op(a, b, out, op); + break; + case float16: + comparison_op(a, b, out, op); + break; + case float32: + comparison_op(a, b, out, op); + break; + case bfloat16: + comparison_op(a, b, out, op); + break; + case complex64: + comparison_op(a, b, out, op); + break; + } +} + +} // namespace + +void Add::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + binary(a, b, out, [](auto x, auto y) { return x + y; }); +} + +void Divide::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + binary(a, b, out, [](auto x, auto y) { return x / y; }); +} + +void Equal::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (equal_nan_) { + comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) { + return x == y || (std::isnan(x) && std::isnan(y)); + }); + } else { + comparison_op( + inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; }); + } +} + +void Greater::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + comparison_op( + inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; }); +} + +void GreaterEqual::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + comparison_op( + inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; }); +} + +void Less::eval(const std::vector& inputs, 
array& out) { + assert(inputs.size() == 2); + comparison_op( + inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; }); +} + +void LessEqual::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + comparison_op( + inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; }); +} + +void LogAddExp::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto op = [](auto x, auto y) { + constexpr float inf = std::numeric_limits::infinity(); + auto maxval = (x > y) ? x : y; + auto minval = (x > y) ? y : x; + return (minval == -inf || maxval == inf) + ? maxval + : static_cast( + maxval + std::log1p(std::exp(minval - maxval))); + }; + if (is_floating_point(out.dtype())) { + if (out.dtype() == float32) { + binary_op(a, b, out, op); + } else if (out.dtype() == float16) { + binary_op(a, b, out, op); + } else if (out.dtype() == bfloat16) { + binary_op(a, b, out, op); + } else { + std::ostringstream err; + err << "[logaddexp] Does not support " << out.dtype(); + throw std::invalid_argument(err.str()); + } + } else { + throw std::invalid_argument( + "[logaddexp] Cannot compute logaddexp for arrays with" + " non floating point type."); + } +} + +void Maximum::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; }); +} + +void Minimum::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; }); +} + +void Multiply::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + binary(a, b, out, [](auto x, auto y) { return x * y; }); +} + +void NotEqual::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + comparison_op( + inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; }); +} + +struct PowerFn { + template + std::enable_if_t, T> operator()(T base, T exp) { + return std::pow(base, exp); + } + + template + std::enable_if_t, T> operator()(T base, T exp) { + if (exp < 0) { + throw std::invalid_argument( + "Integers cannot be raise to negative powers"); + } + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +}; + +void Power::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + binary(a, b, out, PowerFn{}); +} + +void Subtract::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + binary(a, b, out, [](auto x, auto y) { return x - y; }); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h new file mode 100644 index 000000000..744ab5202 --- /dev/null +++ b/mlx/backend/common/binary.h @@ -0,0 +1,554 @@ +#pragma once + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +namespace { + +enum BinaryOpType { + ScalarScalar, + ScalarVector, + VectorScalar, + VectorVector, + General, +}; + +BinaryOpType get_binary_op_type(const array& a, const array& b) { + BinaryOpType bopt; + if (a.data_size() == 1 && b.data_size() == 1) { + bopt = ScalarScalar; + } else if (a.data_size() == 1 && b.flags().contiguous) { + bopt = ScalarVector; + } else if 
(b.data_size() == 1 && a.flags().contiguous) { + bopt = VectorScalar; + } else if ( + a.flags().row_contiguous && b.flags().row_contiguous || + a.flags().col_contiguous && b.flags().col_contiguous) { + bopt = VectorVector; + } else { + bopt = General; + } + return bopt; +} + +void set_binary_op_output_data( + const array& a, + const array& b, + array& out, + BinaryOpType bopt) { + switch (bopt) { + case ScalarScalar: + out.set_data( + allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); + break; + case ScalarVector: + out.set_data( + allocator::malloc_or_wait(b.data_size() * out.itemsize()), + b.data_size(), + b.strides(), + b.flags()); + break; + case VectorScalar: + case VectorVector: + out.set_data( + allocator::malloc_or_wait(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + break; + case General: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + break; + } +} + +struct UseDefaultBinaryOp { + template + void operator()(const T* a, const T* b, U* dst, int size) { + // Should we throw? This should normally never be called. + assert(false); + } +}; + +template +struct DefaultVectorScalar { + Op op; + + DefaultVectorScalar(Op op_) : op(op_) {} + + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *b; + while (size-- > 0) { + *dst = op(*a, scalar); + dst++; + a++; + } + } +}; + +template +struct DefaultScalarVector { + Op op; + + DefaultScalarVector(Op op_) : op(op_) {} + + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *a; + while (size-- > 0) { + *dst = op(scalar, *b); + dst++; + b++; + } + } +}; + +template +struct DefaultVectorVector { + Op op; + + DefaultVectorVector(Op op_) : op(op_) {} + + void operator()(const T* a, const T* b, U* dst, int size) { + while (size-- > 0) { + *dst = op(*a, *b); + dst++; + a++; + b++; + } + } +}; + +template +void binary_op_dims1(const array& a, const array& b, array& out, Op op) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + for (size_t i = 0; i < out.size(); ++i) { + dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]); + a_idx += a.strides()[0]; + b_idx += b.strides()[0]; + } +} + +template +void binary_op_dims1( + const array& a, + const array& b, + array& out, + Op op, + int stride) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + for (size_t i = 0; i < a.shape()[0]; i++) { + op(a_ptr + a_idx, b_ptr + b_idx, dst, stride); + a_idx += a.strides()[0]; + b_idx += b.strides()[0]; + dst += stride; + } +} + +template +void binary_op_dims2(const array& a, const array& b, array& out, Op op) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + size_t out_idx = 0; + for (size_t i = 0; i < a.shape()[0]; ++i) { + for (size_t j = 0; j < a.shape()[1]; ++j) { + dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]); + a_idx += a.strides()[1]; + b_idx += b.strides()[1]; + } + a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; + b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + } +} + +template +void binary_op_dims2( + const array& a, + const array& b, + array& out, + Op op, + int stride) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + for (size_t i = 0; i < a.shape()[0]; ++i) { + for (size_t j = 0; j < a.shape()[1]; ++j) { + op(a_ptr + a_idx, b_ptr + b_idx, dst, 
stride); + a_idx += a.strides()[1]; + b_idx += b.strides()[1]; + dst += stride; + } + a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; + b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + } +} + +template +void binary_op_dims3(const array& a, const array& b, array& out, Op op) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + size_t out_idx = 0; + for (size_t i = 0; i < a.shape()[0]; ++i) { + for (size_t j = 0; j < a.shape()[1]; ++j) { + for (size_t k = 0; k < a.shape()[2]; ++k) { + dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]); + a_idx += a.strides()[2]; + b_idx += b.strides()[2]; + } + a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; + b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; + } + a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; + b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + } +} + +template +void binary_op_dims4(const array& a, const array& b, array& out, Op op) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + size_t out_idx = 0; + for (size_t i = 0; i < a.shape()[0]; ++i) { + for (size_t j = 0; j < a.shape()[1]; ++j) { + for (size_t k = 0; k < a.shape()[2]; ++k) { + for (size_t ii = 0; ii < a.shape()[3]; ++ii) { + dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]); + a_idx += a.strides()[3]; + b_idx += b.strides()[3]; + } + a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3]; + b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3]; + } + a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; + b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; + } + a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; + b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + } +} + +template +void binary_op_dispatch_dims( + const array& a, + const array& b, + array& out, + Op op) { + switch (out.ndim()) { + case 1: + binary_op_dims1(a, b, out, op); + return; + case 2: + binary_op_dims2(a, b, out, op); + return; + case 3: + binary_op_dims3(a, b, out, op); + return; + case 4: + binary_op_dims4(a, b, out, op); + return; + } + + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + for (size_t i = 0; i < out.size(); i++) { + int a_idx = elem_to_loc(i, a.shape(), a.strides()); + int b_idx = elem_to_loc(i, b.shape(), b.strides()); + dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]); + } +} + +template +void binary_op_dispatch_dims( + const array& a, + const array& b, + array& out, + Op op, + int dim, + int stride) { + // Number of dimensions to loop over for vectorized ops + switch (dim) { + case 1: + binary_op_dims1(a, b, out, op, stride); + return; + case 2: + binary_op_dims2(a, b, out, op, stride); + return; + } + + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* dst = out.data(); + for (size_t i = 0; i < out.size(); i += stride) { + int a_idx = elem_to_loc(i, a.shape(), a.strides()); + int b_idx = elem_to_loc(i, b.shape(), b.strides()); + op(a_ptr + a_idx, b_ptr + b_idx, dst, stride); + dst += stride; + } +} + +template < + typename T, + typename U, + typename Op, + typename OpSV, + typename OpVS, + typename OpVV> +void binary_op( + const array& a, + const array& b, + array& out, + Op op, + OpSV opsv, + OpVS opvs, + OpVV opvv) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + // The full computation is scalar scalar so call the base op once + if (bopt == ScalarScalar) { + *(out.data()) = op(*a.data(), 
*b.data()); + return; + } + + // The full computation is scalar vector so delegate to the op + if (bopt == ScalarVector) { + opsv(a.data(), b.data(), out.data(), b.data_size()); + return; + } + + // The full computation is vector scalar so delegate to the op + if (bopt == VectorScalar) { + opvs(a.data(), b.data(), out.data(), a.data_size()); + return; + } + + // The full computation is vector vector so delegate to the op + if (bopt == VectorVector) { + opvv(a.data(), b.data(), out.data(), out.size()); + return; + } + + // General computation so let's try to optimize + + // Get the left-most dim such that the array is row contiguous after + auto& strides = out.strides(); + auto leftmost_rc_dim = [&strides](const array& arr) { + int d = arr.ndim() - 1; + for (; d >= 0 && arr.strides()[d] == strides[d]; d--) { + } + return d + 1; + }; + auto a_rc_dim = leftmost_rc_dim(a); + auto b_rc_dim = leftmost_rc_dim(b); + + // Get the left-most dim such that the array is a broadcasted "scalar" after + auto leftmost_s_dim = [](const array& arr) { + int d = arr.ndim() - 1; + for (; d >= 0 && arr.strides()[d] == 0; d--) { + } + return d + 1; + }; + auto a_s_dim = leftmost_s_dim(a); + auto b_s_dim = leftmost_s_dim(b); + + auto ndim = out.ndim(); + + // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous + int dim = ndim; + if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { + bopt = VectorVector; + dim = d; + // Case 2: LxM and Fx1 where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { + bopt = VectorScalar; + dim = d; + // Case 3: Lx1 and FxM where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { + bopt = ScalarVector; + dim = d; + } + + // Can be sure dim > 0 since otherwise we would have used one of the fully + // contiguous methods above. Except for the case that the flags do not + // correspond to the underlying contiguity. + size_t stride; + if (dim == 0 || strides[dim - 1] < 16) { + stride = 1; + bopt = General; + dim = ndim; + } else { + stride = strides[dim - 1]; + } + + switch (bopt) { + case VectorVector: + binary_op_dispatch_dims(a, b, out, opvv, dim, stride); + break; + case VectorScalar: + binary_op_dispatch_dims(a, b, out, opvs, dim, stride); + break; + case ScalarVector: + binary_op_dispatch_dims(a, b, out, opsv, dim, stride); + break; + default: + binary_op_dispatch_dims(a, b, out, op); + break; + } +} + +template +void binary_op( + const array& a, + const array& b, + array& out, + Op op, + OpSV opsv, + OpVS opvs, + OpVV opvv) { + // TODO: The following mess of constexpr evaluations can probably be achieved + // with template specializations and overloading. Would it be simpler? + + if (std::is_same::value) { + if (std::is_same::value) { + if (std::is_same::value) { + // All ops are UseDefaultBinaryOp (why oh why would someone call that?) 
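      // so rebuild all three vector specializations from the scalar op and
      // dispatch again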
+ binary_op( + a, + b, + out, + op, + DefaultScalarVector(op), + DefaultVectorScalar(op), + DefaultVectorVector(op)); + } else { + // opsv and opvs were UseDefaultBinaryOp + binary_op( + a, + b, + out, + op, + DefaultScalarVector(op), + DefaultVectorScalar(op), + opvv); + } + } else if (std::is_same::value) { + // opsv and opvv were UseDefaultBinaryOp + binary_op( + a, + b, + out, + op, + DefaultScalarVector(op), + opvs, + DefaultVectorVector(op)); + } else { + // opsv was UseDefaultBinaryOp + binary_op( + a, b, out, op, DefaultScalarVector(op), opvs, opvv); + } + } else if (std::is_same::value) { + if (std::is_same::value) { + // opvs and opvv were UseDefaultBinaryOp + binary_op( + a, + b, + out, + op, + opsv, + DefaultVectorScalar(op), + DefaultVectorVector(op)); + } else { + // opvs was UseDefaultBinaryOp + binary_op( + a, b, out, op, opsv, DefaultVectorScalar(op), opvv); + } + } else if (std::is_same::value) { + // opvv was UseDefaultBinaryOp + binary_op( + a, b, out, op, opsv, opvs, DefaultVectorVector(op)); + } else { + // All ops provided + binary_op(a, b, out, op, opsv, opvs, opvv); + } +} + +template +void binary_op(const array& a, const array& b, array& out, Op op) { + DefaultScalarVector opsv(op); + DefaultVectorScalar opvs(op); + DefaultVectorVector opvv(op); + binary_op(a, b, out, op, opsv, opvs, opvv); +} + +template +void binary(const array& a, const array& b, array& out, Ops... ops) { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, ops...); + break; + case uint8: + binary_op(a, b, out, ops...); + break; + case uint16: + binary_op(a, b, out, ops...); + break; + case uint32: + binary_op(a, b, out, ops...); + break; + case uint64: + binary_op(a, b, out, ops...); + break; + case int8: + binary_op(a, b, out, ops...); + break; + case int16: + binary_op(a, b, out, ops...); + break; + case int32: + binary_op(a, b, out, ops...); + break; + case int64: + binary_op(a, b, out, ops...); + break; + case float16: + binary_op(a, b, out, ops...); + break; + case float32: + binary_op(a, b, out, ops...); + break; + case bfloat16: + binary_op(a, b, out, ops...); + break; + case complex64: + binary_op(a, b, out, ops...); + break; + } +} + +} // namespace + +} // namespace mlx::core diff --git a/mlx/backend/common/fft.cpp b/mlx/backend/common/fft.cpp new file mode 100644 index 000000000..385d9cb71 --- /dev/null +++ b/mlx/backend/common/fft.cpp @@ -0,0 +1,85 @@ +#include + +#include "mlx/3rdparty/pocketfft.h" +#include "mlx/allocator.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void FFT::eval(const std::vector& inputs, array& out) { + auto& in = inputs[0]; + std::vector strides_in( + in.strides().begin(), in.strides().end()); + for (auto& s : strides_in) { + s *= in.itemsize(); + } + std::vector strides_out( + out.strides().begin(), out.strides().end()); + for (auto& s : strides_out) { + s *= out.itemsize(); + } + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + std::vector shape; + if (out.dtype() == float32) { + shape.insert(shape.end(), out.shape().begin(), out.shape().end()); + } else { + shape.insert(shape.end(), in.shape().begin(), in.shape().end()); + } + + float scale = 1.0f; + if (inverse_) { + size_t nelem = std::accumulate( + axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) { + return x * shape[y]; + }); + scale /= nelem; + } + if (in.dtype() == complex64 && out.dtype() == complex64) { + auto in_ptr = + reinterpret_cast*>(in.data()); + auto out_ptr = + reinterpret_cast*>(out.data()); + pocketfft::c2c( + shape, + strides_in, + strides_out, 
+ axes_, + !inverse_, + in_ptr, + out_ptr, + scale); + } else if (in.dtype() == float32 && out.dtype() == complex64) { + auto in_ptr = in.data(); + auto out_ptr = + reinterpret_cast*>(out.data()); + pocketfft::r2c( + shape, + strides_in, + strides_out, + axes_, + !inverse_, + in_ptr, + out_ptr, + scale); + } else if (in.dtype() == complex64 && out.dtype() == float32) { + auto in_ptr = + reinterpret_cast*>(in.data()); + auto out_ptr = out.data(); + pocketfft::c2r( + shape, + strides_in, + strides_out, + axes_, + !inverse_, + in_ptr, + out_ptr, + scale); + } else { + throw std::runtime_error( + "[FFT] Received unexpected input and output type combination."); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp new file mode 100644 index 000000000..dabf81143 --- /dev/null +++ b/mlx/backend/common/softmax.cpp @@ -0,0 +1,98 @@ +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void softmax(const array& in, array& out) { + const T* in_ptr = in.data(); + T* out_ptr = out.data(); + int N = in.shape().back(); + int M = in.data_size() / N; + const T* current_in_ptr; + T* current_out_ptr; + + for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) { + // Find the maximum + current_in_ptr = in_ptr; + T maximum = *current_in_ptr; + for (int j = 0; j < N; j++, current_in_ptr++) { + maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum; + } + + // Compute the normalizer and the exponentials + T normalizer = 0; + current_out_ptr = out_ptr; + current_in_ptr = in_ptr; + for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) { + T expv = std::exp(*current_in_ptr - maximum); + normalizer += expv; + *current_out_ptr = expv; + } + normalizer = 1 / normalizer; + + // Normalize + current_out_ptr = out_ptr; + for (int j = 0; j < N; j++, current_out_ptr++) { + *current_out_ptr *= normalizer; + } + } +} + +} // namespace + +void Softmax::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + // Make sure that the last dimension is contiguous + auto check_input = [](array x) { + if (x.strides().back() == 1) { + return x; + } else { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy(x, x_copy, CopyType::General); + return x_copy; + } + }; + array in = check_input(std::move(inputs[0])); + out.set_data( + allocator::malloc_or_wait(in.data_size() * in.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + + switch (in.dtype()) { + case bool_: + case uint8: + case uint16: + case uint32: + case uint64: + case int8: + case int16: + case int32: + case int64: + throw std::invalid_argument( + "Softmax is defined only for floating point types"); + break; + case float32: + softmax(in, out); + break; + case float16: + softmax(in, out); + break; + case bfloat16: + softmax(in, out); + break; + case complex64: + throw std::invalid_argument( + "[Softmax] Not yet implemented for complex64"); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp new file mode 100644 index 000000000..40ec85a13 --- /dev/null +++ b/mlx/backend/metal/allocator.cpp @@ -0,0 +1,200 @@ +#include "mlx/backend/metal/allocator.h" +#include "mlx/backend/metal/metal.h" + +#include +#include +#include + +namespace mlx::core { + +namespace allocator { + +Allocator& allocator() { + return metal::allocator(); +} + +void* Buffer::raw_ptr() { + return static_cast(ptr_)->contents(); +} + +} // 
namespace allocator + +namespace metal { + +namespace { + +BufferCache::BufferCache(MTL::Device* device) + : device_(device), + head_(nullptr), + tail_(nullptr), + pool_size_(0), + gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {} + +BufferCache::~BufferCache() { + clear(); +} + +void BufferCache::clear() { + std::lock_guard lk(cache_mutex_); + for (auto& [size, holder] : buffer_pool_) { + if (holder->buf) + holder->buf->release(); + delete holder; + } + buffer_pool_.clear(); + pool_size_ = 0; + head_ = nullptr; + tail_ = nullptr; +} + +MTL::Buffer* BufferCache::reuse_from_cache(size_t size) { + std::lock_guard lk(cache_mutex_); + + // Find the closest buffer in pool + MTL::Buffer* pbuf = nullptr; + auto it = buffer_pool_.lower_bound(size); + + // Make sure we use > 50% of the available memory + while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) { + // Collect from the cache + pbuf = it->second->buf; + // Remove from cache + remove_from_list(it->second); + delete it->second; + it = buffer_pool_.erase(it); + } + + if (pbuf) { + pool_size_ -= pbuf->length(); + } + + return pbuf; +} + +void BufferCache::recycle_to_cache(MTL::Buffer* buf) { + std::lock_guard lk(cache_mutex_); + + // Add to cache + if (buf) { + BufferHolder* bh = new BufferHolder(buf); + add_at_head(bh); + pool_size_ += buf->length(); + buffer_pool_.insert({buf->length(), bh}); + } +} + +size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) { + min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_; + + if (min_bytes_to_free >= 0.9 * pool_size_) { + size_t old_pool_size = pool_size_; + clear(); + return old_pool_size; + } else { + std::lock_guard lk(cache_mutex_); + size_t total_bytes_freed = 0; + + while (tail_ && (total_bytes_freed < min_bytes_to_free)) { + if (tail_->buf) { + total_bytes_freed += tail_->buf->length(); + tail_->buf->release(); + tail_->buf = nullptr; + } + remove_from_list(tail_); + } + + pool_size_ -= total_bytes_freed; + return total_bytes_freed; + } +} + +void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) { + if (!to_add) + return; + + if (!head_) { + head_ = to_add; + tail_ = to_add; + } else { + head_->prev = to_add; + to_add->next = head_; + head_ = to_add; + } +} + +void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { + if (!to_remove) + return; + + // If in the middle + if (to_remove->prev && to_remove->next) { + to_remove->prev->next = to_remove->next; + to_remove->next->prev = to_remove->prev; + } else if (to_remove->prev && to_remove == tail_) { // If tail + tail_ = to_remove->prev; + tail_->next = nullptr; + } else if (to_remove == head_ && to_remove->next) { // If head + head_ = to_remove->next; + head_->prev = nullptr; + } else if (to_remove == head_ && to_remove == tail_) { // If only element + head_ = nullptr; + tail_ = nullptr; + } + + to_remove->prev = nullptr; + to_remove->next = nullptr; +} + +} // namespace + +MetalAllocator::MetalAllocator() + : device_(device(mlx::core::Device::gpu).mtl_device()), + buffer_cache_(device_), + peak_allocated_size_(0), + block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {} + +Buffer MetalAllocator::malloc(size_t size) { + // Align up memory + if (size > vm_page_size) { + size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size); + } + + MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); + + // Prepare to allocate new memory as needed + if (!buf) { + // If we are under very high memoory pressure, we don't allocate further + if 
(device_->currentAllocatedSize() >= block_limit_) { + return Buffer{nullptr}; + } + + // If we are still under memory pressure, try cleaning cache + if (buffer_cache_.can_garbage_collect()) { + buffer_cache_.release_cached_buffers(size); + } + + // Allocate new buffer if needed + size_t res_opt = MTL::ResourceStorageModeShared; + res_opt |= MTL::ResourceHazardTrackingModeTracked; + buf = device_->newBuffer(size, res_opt); + } + + peak_allocated_size_ = + std::max(peak_allocated_size_, device_->currentAllocatedSize()); + + return Buffer{static_cast(buf)}; +} + +void MetalAllocator::free(Buffer buffer) { + auto buf = static_cast(buffer.ptr()); + buffer_cache_.recycle_to_cache(buf); +} + +MetalAllocator& allocator() { + static MetalAllocator allocator_; + return allocator_; +} + +} // namespace metal + +} // namespace mlx::core diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h new file mode 100644 index 000000000..92ede1757 --- /dev/null +++ b/mlx/backend/metal/allocator.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/backend/metal/device.h" + +namespace mlx::core::metal { + +using allocator::Buffer; + +namespace { + +class BufferCache { + public: + BufferCache(MTL::Device* device); + ~BufferCache(); + void clear(); + + MTL::Buffer* reuse_from_cache(size_t size); + void recycle_to_cache(MTL::Buffer* buf); + size_t release_cached_buffers(size_t min_bytes_to_free); + + bool can_garbage_collect() { + return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_; + } + + private: + struct BufferHolder { + public: + BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {} + + BufferHolder* prev; + BufferHolder* next; + MTL::Buffer* buf; + }; + + void add_at_head(BufferHolder* to_add); + void remove_from_list(BufferHolder* to_remove); + + MTL::Device* device_; + std::mutex cache_mutex_; + + std::multimap buffer_pool_; + BufferHolder* head_; + BufferHolder* tail_; + size_t pool_size_; + size_t gc_limit_; +}; + +} // namespace + +class MetalAllocator : public allocator::Allocator { + /** Allocator for Metal GPUs. 
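+ *
+ * malloc() first tries to reuse a buffer from the size-keyed BufferCache and
+ * only then allocates a new shared-storage MTL::Buffer; free() recycles
+ * buffers back into the cache rather than releasing them immediately.
+ *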
*/ + public: + virtual Buffer malloc(size_t size) override; + virtual void free(Buffer buffer) override; + + private: + MTL::Device* device_; + MetalAllocator(); + friend MetalAllocator& allocator(); + + // Caching allocator + BufferCache buffer_cache_; + + // Allocation stats + size_t peak_allocated_size_; + size_t block_limit_; +}; + +MetalAllocator& allocator(); + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp new file mode 100644 index 000000000..e13297174 --- /dev/null +++ b/mlx/backend/metal/conv.cpp @@ -0,0 +1,555 @@ +#include +#include +#include +#include +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/conv_params.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/matmul.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +void explicit_gemm_conv_1D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<1>& conv_params) { + // Pad input + std::vector padded_shape = { + conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C}; + array in_padded(padded_shape, in.dtype(), nullptr, {}); + + // Fill with zeros + copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s); + + // Pick input slice from padded + size_t data_offset = conv_params.pad[0] * in_padded.strides()[1]; + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); + in_padded_slice.copy_shared_buffer( + in_padded, + in_padded.strides(), + in_padded.flags(), + in_padded_slice.size(), + data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s); + + // Make strided view + std::vector strided_shape = { + conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C}; + + std::vector strided_strides = { + in_padded.strides()[0], + in_padded.strides()[1] * conv_params.str[0], + in_padded.strides()[1], + in_padded.strides()[2]}; + auto flags = in_padded.flags(); + + array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); + in_strided_view.copy_shared_buffer( + in_padded, strided_strides, flags, in_strided_view.size(), 0); + + // Materialize strided view + std::vector strided_reshape = { + conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C}; + array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); + copy_gpu(in_strided_view, in_strided, CopyType::General, s); + + // Peform gemm + std::vector copies = {in_padded, in_strided}; + mlx_matmul( + s, + d, + /*a = */ in_strided, + /*b = */ wt, + /*c = */ out, + /*M = */ strided_reshape[0], + /*N = */ conv_params.O, + /*K = */ strided_reshape[1], + /*batch_size_out = */ 1, + /*a_cols = */ strided_reshape[1], + /*b_cols = */ strided_reshape[1], + /*a_transposed = */ false, + /*b_transposed = */ true, + /*copies = */ copies); +} + +void conv_1D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + // Make conv params + MLXConvParams<1> conv_params{ + /* const int N = */ in.shape(0), + /* const int C = */ in.shape(2), + /* const int O = */ wt.shape(0), + /* const int iS[NDIM] = */ {in.shape(1)}, + /* const int wS[NDIM] = */ {wt.shape(1)}, + /* const int oS[NDIM] = */ {out.shape(1)}, + /* const int 
str[NDIM] = */ {wt_strides[0]}, + /* const int pad[NDIM] = */ {padding[0]}, + /* const int dil[NDIM] = */ {wt_dilation[0]}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], in.strides()[1], in.strides()[2]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], out.strides()[2]}, + }; + + // Direct to explicit gemm conv + if (wt_dilation[0] == 1) { + explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params); + } + + // Direct to fallback conv + else { + throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1."); + } +} + +void slow_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params) { + int bm = 16, bn = 8; + int tm = 4, tn = 4; + + std::ostringstream kname; + kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn + << "_tm" << tm << "_tn" << tn; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + size_t n_pixels = conv_params.oS[0] * conv_params.oS[1]; + + size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm); + size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn); + size_t grid_dim_z = conv_params.N; + + MTL::Size group_dims = MTL::Size(bm, bn, 1); + MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z); + + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, wt, 1); + set_array_buffer(compute_encoder, out, 2); + + compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); +} + +void implicit_gemm_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params) { + int bm = 32, bn = 32, bk = 16; + int wm = 2, wn = 2; + + std::ostringstream kname; + kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" + << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; + int implicit_N = conv_params.O; + int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C; + + size_t grid_dim_x = (implicit_N + bn - 1) / bn; + size_t grid_dim_y = (implicit_M + bm - 1) / bm; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1); + + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, wt, 1); + set_array_buffer(compute_encoder, out, 2); + + compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); +} + +void explicit_gemm_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params) { + // Pad input + std::vector padded_shape = { + conv_params.N, + conv_params.iS[0] + 2 * conv_params.pad[0], + conv_params.iS[1] + 2 * conv_params.pad[1], + conv_params.C}; + array in_padded(padded_shape, in.dtype(), nullptr, {}); + + // Fill with zeros + copy_gpu(array(0, in.dtype()), in_padded, 
CopyType::Scalar, s); + + // Pick input slice from padded + size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] + + conv_params.pad[1] * in_padded.strides()[2]; + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); + in_padded_slice.copy_shared_buffer( + in_padded, + in_padded.strides(), + in_padded.flags(), + in_padded_slice.size(), + data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s); + + // Make strided view + std::vector strided_shape = { + conv_params.N, + conv_params.oS[0], + conv_params.oS[1], + conv_params.wS[0], + conv_params.wS[1], + conv_params.C}; + + std::vector strided_strides = { + in_padded.strides()[0], + in_padded.strides()[1] * conv_params.str[0], + in_padded.strides()[2] * conv_params.str[1], + in_padded.strides()[1], + in_padded.strides()[2], + in_padded.strides()[3]}; + auto flags = in_padded.flags(); + + array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); + in_strided_view.copy_shared_buffer( + in_padded, strided_strides, flags, in_strided_view.size(), 0); + + // Materialize strided view + std::vector strided_reshape = { + conv_params.N * conv_params.oS[0] * conv_params.oS[1], + conv_params.wS[0] * conv_params.wS[1] * conv_params.C}; + array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); + copy_gpu(in_strided_view, in_strided, CopyType::General, s); + + // Peform gemm + std::vector copies = {in_padded, in_strided}; + mlx_matmul( + s, + d, + /*a = */ in_strided, + /*b = */ wt, + /*c = */ out, + /*M = */ strided_reshape[0], + /*N = */ conv_params.O, + /*K = */ strided_reshape[1], + /*batch_size_out = */ 1, + /*a_cols = */ strided_reshape[1], + /*b_cols = */ strided_reshape[1], + /*a_transposed = */ false, + /*b_transposed = */ true, + /*copies = */ copies); +} + +void winograd_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params, + std::vector& copies_w) { + std::vector padded_shape = { + conv_params.N, + conv_params.iS[0] + 2 * conv_params.pad[0], + conv_params.iS[1] + 2 * conv_params.pad[1], + conv_params.C}; + + padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2; + padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2; + + array in_padded(padded_shape, in.dtype(), nullptr, {}); + + // Fill with zeros + array zero_arr = array(0, in.dtype()); + copy_gpu(zero_arr, in_padded, CopyType::Scalar, s); + + // Pick input slice from padded + size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] + + conv_params.pad[1] * in_padded.strides()[2]; + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); + in_padded_slice.copy_shared_buffer( + in_padded, + in_padded.strides(), + in_padded.flags(), + in_padded_slice.size(), + data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s); + + copies_w.push_back(in_padded_slice); + copies_w.push_back(in_padded); + copies_w.push_back(zero_arr); + + MLXConvParams<2> conv_params_updated{ + /* const int N = */ in_padded.shape(0), + /* const int C = */ in_padded.shape(3), + /* const int O = */ wt.shape(0), + /* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)}, + /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)}, + /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)}, + /* const int str[NDIM] = */ {1, 1}, + /* const int pad[NDIM] = */ {0, 0}, + /* const int dil[NDIM] = */ {1, 1}, + /* const size_t 
in_strides[NDIM + 2] = */ + {in_padded.strides()[0], + in_padded.strides()[1], + in_padded.strides()[2], + in_padded.strides()[3]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]}, + }; + + int O_c = conv_params.O; + int C_c = conv_params.C; + + int N_tiles_n = conv_params.N; + int N_tiles_h = (conv_params.oS[0] + 5) / 6; + int N_tiles_w = (conv_params.oS[1] + 5) / 6; + int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w; + + // Do filter transform + std::vector filt_wg_shape = {8 * 8, conv_params.C, conv_params.O}; + array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {}); + filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes())); + copies_w.push_back(filt_wg); + { + int bc = 32; + int bo = 4; + std::ostringstream kname; + kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc" + << bc; + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + set_array_buffer(compute_encoder, wt, 0); + set_array_buffer(compute_encoder, filt_wg, 1); + + compute_encoder->setBytes(&C_c, sizeof(int), 2); + compute_encoder->setBytes(&O_c, sizeof(int), 3); + + MTL::Size group_dims = MTL::Size(32, bo, 1); + MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + + // Do input transform + std::vector inp_wg_shape = {8 * 8, N_tiles, conv_params.C}; + array inp_wg(inp_wg_shape, in.dtype(), nullptr, {}); + inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes())); + copies_w.push_back(inp_wg); + { + int bc = 32; + int wm = 2; + int wn = 2; + std::ostringstream kname; + kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc" + << bc; + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + set_array_buffer(compute_encoder, in_padded, 0); + set_array_buffer(compute_encoder, inp_wg, 1); + + compute_encoder->setBytes( + &conv_params_updated, sizeof(MLXConvParams<2>), 2); + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + + // Do batched gemm + std::vector out_wg_shape = {8 * 8, N_tiles, conv_params.O}; + array out_wg(out_wg_shape, in.dtype(), nullptr, {}); + out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes())); + copies_w.push_back(out_wg); + { + std::vector empty_copies; + mlx_matmul( + s, + d, + /*a = */ inp_wg, + /*b = */ filt_wg, + /*c = */ out_wg, + /*M = */ N_tiles, + /*N = */ conv_params.O, + /*K = */ conv_params.C, + /*batch_size_out = */ 8 * 8, + /*a_cols = */ conv_params.C, + /*b_cols = */ conv_params.O, + /*a_transposed = */ false, + /*b_transposed = */ false, + /*copies = */ empty_copies); + } + + // Do output transform + { + int bc = 32; + int wm = 2; + int wn = 2; + std::ostringstream kname; + kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo" + << bc; + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + set_array_buffer(compute_encoder, out_wg, 0); + set_array_buffer(compute_encoder, out, 1); + + compute_encoder->setBytes( + &conv_params_updated, 
sizeof(MLXConvParams<2>), 2); + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } +} + +void conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + std::vector& copies) { + // Make conv params + MLXConvParams<2> conv_params{ + /* const int N = */ in.shape(0), + /* const int C = */ in.shape(3), + /* const int O = */ wt.shape(0), + /* const int iS[NDIM] = */ {in.shape(1), in.shape(2)}, + /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)}, + /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)}, + /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]}, + /* const int pad[NDIM] = */ {padding[0], padding[1]}, + /* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]}, + }; + + // Direct to winograd conv + if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && + conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 && + conv_params.wS[1] == 3 && conv_params.str[0] == 1 && + conv_params.str[1] == 1 && conv_params.dil[0] == 1 && + conv_params.dil[1] == 1) { + winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + } + + // Direct to implicit gemm conv + else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) { + implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + // Direct to explicit gemm conv + else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) { + explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + // Direct to fallback conv + else { + slow_conv_2D_gpu(s, d, in, wt, out, conv_params); + } +} + +} // namespace + +void Convolution::eval_gpu(const std::vector& inputs, array& out) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + + // Ensure contiguity + std::vector copies; + auto in = inputs[0]; + auto wt = inputs[1]; + if (!in.flags().row_contiguous) { + array arr_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + in = arr_copy; + } + if (!wt.flags().row_contiguous) { + array arr_copy(wt.shape(), wt.dtype(), nullptr, {}); + copy_gpu(wt, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + wt = arr_copy; + } + + // 2D conv + if (out.ndim() == 4) { + conv_2D_gpu( + s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies); + } + // 1D conv + else if (out.ndim() == 3) { + conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_); + } + // Throw error + else { + throw std::invalid_argument( + "[Convolution::eval_gpu] Only supports 1D or 2D convolutions."); + } + + // Clear copies + if (copies.size() > 0) { + auto command_buffer = d.get_command_buffer(s.index); + command_buffer->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/copy.h b/mlx/backend/metal/copy.h new file mode 100644 index 000000000..02d294a84 --- /dev/null 
+++ b/mlx/backend/metal/copy.h @@ -0,0 +1,16 @@ +#pragma once + +#include "mlx/backend/common/copy.h" +#include "mlx/stream.h" + +namespace mlx::core { + +void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); +void copy_gpu(const array& src, array& out, CopyType ctype); +void copy_gpu_inplace( + const array& src, + array& out, + CopyType ctype, + const Stream& s); + +} // namespace mlx::core diff --git a/mlx/backend/metal/kernels/arange.metal b/mlx/backend/metal/kernels/arange.metal new file mode 100644 index 000000000..f3afc24f6 --- /dev/null +++ b/mlx/backend/metal/kernels/arange.metal @@ -0,0 +1,30 @@ +#include "mlx/backend/metal/kernels/bf16.h" + +template +[[kernel]] void arange( + constant const T& start, + constant const T& step, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = start + index * step; +} + +#define instantiate_arange(tname, type) \ + template [[host_name("arange" #tname)]] \ + [[kernel]] void arange( \ + constant const type& start, \ + constant const type& step, \ + device type* out, \ + uint index [[thread_position_in_grid]]); + +instantiate_arange(uint8, uint8_t) +instantiate_arange(uint16, uint16_t) +instantiate_arange(uint32, uint32_t) +instantiate_arange(uint64, uint64_t) +instantiate_arange(int8, int8_t) +instantiate_arange(int16, int16_t) +instantiate_arange(int32, int32_t) +instantiate_arange(int64, int64_t) +instantiate_arange(float16, half) +instantiate_arange(float32, float) +instantiate_arange(bfloat16, bfloat16_t) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal new file mode 100644 index 000000000..8cd13afca --- /dev/null +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -0,0 +1,208 @@ +#include +#include + +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +template +struct IndexValPair { + uint32_t index; + U val; + + IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {} +}; + +template +struct ArgMin { + static constexpr constant U init = Limits::max; + + IndexValPair reduce(IndexValPair best, IndexValPair current) { + if (best.val > current.val || (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + IndexValPair reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { + for (int i=0; i +struct ArgMax { + static constexpr constant U init = Limits::min; + + IndexValPair reduce(IndexValPair best, IndexValPair current) { + if (best.val < current.val || (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + IndexValPair reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { + for (int i=0; i best.val) { + best.val = vals[i]; + best.index = offset+i; + } + } + return best; + } +}; + +bool simd_shuffle_down(bool data, uint16_t delta) { + return simd_shuffle_down(static_cast(data), delta); +} + +uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { + return as_type(simd_shuffle_down(as_type(data), delta)); +} + +int64_t simd_shuffle_down(int64_t data, uint16_t delta) { + return as_type(simd_shuffle_down(as_type(data), delta)); +} + +template +IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) { + return IndexValPair( + simd_shuffle_down(data.index, delta), + simd_shuffle_down(data.val, delta) + ); +} + + +template +[[kernel]] void arg_reduce_general( + const device T *in [[buffer(0)]], + device uint32_t *out 
[[buffer(1)]], + const device int *shape [[buffer(2)]], + const device size_t *in_strides [[buffer(3)]], + const device size_t *out_strides [[buffer(4)]], + const device size_t& ndim [[buffer(5)]], + const device size_t& axis_stride [[buffer(6)]], + const device size_t& axis_size [[buffer(7)]], + threadgroup IndexValPair *local_data [[threadgroup(0)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + + // Shapes and strides *do not* contain the reduction axis. The reduction size + // and stride are provided in axis_stride and axis_size. + // + // Note: in shape == out shape with this convention. + // + // The sketch of the kernel is as follows. + // 1. Launch prod(shape) * thread_group_size threads. + // 2. Loop ceildiv(axis_size / lsize) times + // 3. Read input values + // 4. Reduce among them and go to 3 + // 4. Reduce in each simd_group + // 6. Write in the thread local memory + // 6. Reduce them accross thread group + // 7. Write the output without need for atomic + Op op; + + // Compute the input/output index. There is one beginning and one output for + // the whole threadgroup. + auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); + auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); + + IndexValPair best(0, Op::init); + + // Loop over the reduction axis in lsize*N_READS buckets + for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) { + // Read the current value + uint32_t current_index = r*lsize*N_READS + lid*N_READS; + uint32_t offset = current_index; + const device T * current_in = in + in_idx + current_index * axis_stride; + T vals[N_READS]; + for (int i=0; i(best, vals, offset); + } + // At this point we have reduced the axis into thread group best values so we + // need to reduce across the thread group. + + // First per simd reduction. 
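+  // simd_shuffle_down(best, offset) reads the running best from the lane
+  // `offset` positions above the current one; halving the offset each step
+  // performs a log2(simd_size) tree reduction whose result lands in lane 0
+  // of the simdgroup.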
+ for (uint offset=simd_size/2; offset>0; offset/=2) { + IndexValPair neighbor = simd_shuffle_down(best, offset); + best = op.reduce(best, neighbor); + } + + // Write to the threadgroup memory + if (simd_lane_id == 0) { + local_data[simd_group_id] = best; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id != 0) { + return; + } + + // Read the appropriate value from local data and perform one simd reduction + uint simd_groups = ceildiv(lsize, simd_size); + if (simd_lane_id < simd_groups) { + best = local_data[simd_lane_id]; + } + for (uint offset=simd_size/2; offset>0; offset/=2) { + IndexValPair neighbor = simd_shuffle_down(best, offset); + best = op.reduce(best, neighbor); + } + + // Finally write the output + if (lid == 0) { + out[out_idx] = best.index; + } +} + +#define instantiate_arg_reduce_helper(name, itype, op) \ + template [[host_name(name)]] \ + [[kernel]] void arg_reduce_general, 4>( \ + const device itype *in [[buffer(0)]], \ + device uint32_t * out [[buffer(1)]], \ + const device int *shape [[buffer(2)]], \ + const device size_t *in_strides [[buffer(3)]], \ + const device size_t *out_strides [[buffer(4)]], \ + const device size_t& ndim [[buffer(5)]], \ + const device size_t& axis_stride [[buffer(6)]], \ + const device size_t& axis_size [[buffer(7)]], \ + threadgroup IndexValPair *local_data [[threadgroup(0)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_size [[threads_per_simdgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_arg_reduce(name, itype) \ + instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \ + instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax) + +instantiate_arg_reduce(bool_, bool) +instantiate_arg_reduce(uint8, uint8_t) +instantiate_arg_reduce(uint16, uint16_t) +instantiate_arg_reduce(uint32, uint32_t) +instantiate_arg_reduce(uint64, uint64_t) +instantiate_arg_reduce(int8, int8_t) +instantiate_arg_reduce(int16, int16_t) +instantiate_arg_reduce(int32, int32_t) +instantiate_arg_reduce(int64, int64_t) +instantiate_arg_reduce(float16, half) +instantiate_arg_reduce(float32, float) +instantiate_arg_reduce(bfloat16, bfloat16_t) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/bf16_math.h b/mlx/backend/metal/kernels/bf16_math.h new file mode 100644 index 000000000..824ecc9ef --- /dev/null +++ b/mlx/backend/metal/kernels/bf16_math.h @@ -0,0 +1,392 @@ +#pragma once + +#include "mlx/backend/metal/kernels/bf16.h" + +/////////////////////////////////////////////////////////////////////////////// +// Metal math for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +/* + +Following the Metal Shading Language Specification (Metal 3.1) + +"bfloat is an extended itypeing point type that only allows implicit conversion + to a type of greater itypeing point rank. While bfloat can be implicitly + converted to itype, it cannot be implicitly converted to half, and neither + itype nor half can be implicitly converted to bfloat." 
+ +Further, as far as I can tell, the stdlib math/simd functions are not defined +for bfloat and calling with an argument of type bfloat will result in that +argument getting implicitly converted to itype which then returns an output +that is (likely) a itype which cannot be implicitly converted into a bfloat + +This leads to situations where +bfloat a = 5.0bf; +bfloat b = metal::abs(a); // this will throw an error since abs return itype +bfloat c = static_cast(metal::abs(a)); // this is fine + +For the moment, I will be adding overloaded instantiations of the math +functions to accordingly automatically handle the casting + +*/ + +#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \ + \ + METAL_FUNC otype abs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acos(itype x) { \ + return static_cast(__metal_acos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acosh(itype x) { \ + return static_cast(__metal_acosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asin(itype x) { \ + return static_cast(__metal_asin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asinh(itype x) { \ + return static_cast(__metal_asinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atan(itype y_over_x) { \ + return static_cast( \ + __metal_atan(static_cast(y_over_x), mfast)); \ + } \ + METAL_FUNC otype atan2(itype y, itype x) { \ + return static_cast( \ + __metal_atan2(static_cast(y), static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atanh(itype x) { \ + return static_cast(__metal_atanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype ceil(itype x) { \ + return static_cast(__metal_ceil(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cos(itype x) { \ + return static_cast(__metal_cos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cosh(itype x) { \ + return static_cast(__metal_cosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cospi(itype x) { \ + return static_cast(__metal_cospi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype divide(itype x, itype y) { \ + return static_cast( \ + __metal_divide(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype exp(itype x) { \ + return static_cast(__metal_exp(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp10(itype x) { \ + return static_cast(__metal_exp10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp2(itype x) { \ + return static_cast(__metal_exp2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fabs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fdim(itype x, itype y) { \ + ctype t = static_cast(x - y); \ + return static_cast(select(t, ctype(0), t < ctype(0) || x == y)); \ + } \ + METAL_FUNC otype floor(itype x) { \ + return static_cast(__metal_floor(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fma(itype x, itype y, itype z) { \ + return static_cast(__metal_fma( \ + static_cast(x), static_cast(y), static_cast(z))); \ + } \ + METAL_FUNC otype fmax(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmax3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmin(itype x, itype y) { \ + return static_cast( \ + 
__metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmin3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmod(itype x, itype y) { \ + return static_cast( \ + __metal_fmod(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fract(itype x) { \ + return static_cast(__metal_fract(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype frexp(itype x, thread int& exp) { \ + return static_cast(__metal_frexp(static_cast(x), &exp)); \ + } \ + METAL_FUNC otype ldexp(itype x, int k) { \ + return static_cast(__metal_ldexp(static_cast(x), k, mfast)); \ + } \ + METAL_FUNC otype log(itype x) { \ + return static_cast(__metal_log(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log10(itype x) { \ + return static_cast(__metal_log10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log2(itype x) { \ + return static_cast(__metal_log2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype max(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype max3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype median3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype min(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype min3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype nextafter(itype x, itype y) { \ + return static_cast( \ + __metal_nextafter(static_cast(x), static_cast(y))); \ + } \ + METAL_FUNC otype pow(itype x, itype y) { \ + return static_cast( \ + __metal_pow(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype powr(itype x, itype y) { \ + return static_cast( \ + __metal_powr(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype rint(itype x) { \ + return static_cast(__metal_rint(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype round(itype x) { \ + return static_cast(__metal_round(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype rsqrt(itype x) { \ + return static_cast(__metal_rsqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sin(itype x) { \ + return static_cast(__metal_sin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinh(itype x) { \ + return static_cast(__metal_sinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinpi(itype x) { \ + return static_cast(__metal_sinpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sqrt(itype x) { \ + return static_cast(__metal_sqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tan(itype x) { \ + return static_cast(__metal_tan(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanh(itype x) { \ + return static_cast(__metal_tanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanpi(itype x) { \ + return static_cast(__metal_tanpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype trunc(itype x) { \ + return static_cast(__metal_trunc(static_cast(x), mfast)); \ + } + +namespace metal { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_MAYBE_FAST_MATH__); + +namespace fast { + +instantiate_metal_math_funcs( + bfloat16_t, + 
bfloat16_t, + float, + __METAL_FAST_MATH__); + +} // namespace fast + +namespace precise { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_PRECISE_MATH__); + +} // namespace precise + +} // namespace metal + +/////////////////////////////////////////////////////////////////////////////// +// Metal simd for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_metal_simd_comm_funcs( \ + itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ + \ + METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ + } + +#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ + \ + METAL_FUNC otype simd_max(itype data) { \ + return static_cast(__metal_simd_max(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_min(itype data) { \ + return static_cast(__metal_simd_min(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ + return static_cast( \ + 
__metal_simd_prefix_inclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_product(itype data) { \ + return static_cast(__metal_simd_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_sum(itype data) { \ + return static_cast(__metal_simd_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_xor(itype data) { \ + return static_cast(__metal_simd_xor(static_cast(data))); \ + } + +#if defined(__HAVE_BFLOAT__) + +#define bfloat16_to_uint16(x) as_type(x) +#define uint16_to_bfloat16(x) as_type(x) + +#else + +#define bfloat16_to_uint16(x) x.bits_ +#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()) + +#endif + +namespace metal { + +instantiate_metal_simd_comm_funcs( + bfloat16_t, + bfloat16_t, + uint16_t, + bfloat16_to_uint16, + uint16_to_bfloat16); +instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); + +} // namespace metal \ No newline at end of file diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal new file mode 100644 index 000000000..65329d19f --- /dev/null +++ b/mlx/backend/metal/kernels/conv.metal @@ -0,0 +1,553 @@ +#include + +#include "mlx/backend/metal/kernels/conv_params.h" +#include "mlx/backend/metal/kernels/bf16.h" + +#include "mlx/backend/metal/kernels/gemm/conv.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +/// Slow and naive kernels +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void naive_conv_2d( + const device T* in [[buffer(0)]], + const device T* wt [[buffer(1)]], + device T* out [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + (void)simd_gid; + (void)simd_lid; + + out += tid.z * params.out_strides[0]; + in += tid.z * params.in_strides[0]; + + int out_o = tid.y * BN * TN + lid.y * TN; + int out_hw = tid.x * BM * TM + lid.x * TM; + + int out_h[TM]; + int out_w[TN]; + + for(int m = 0; m < TM; ++m) { + int mm = (out_hw + m); + out_h[m] = mm / params.oS[1]; + out_w[m] = mm % params.oS[1]; + } + + + T in_local[TM]; + T wt_local[TN]; + T out_local[TM * TN] = {T(0)}; + + for(int h = 0; h < params.wS[0]; ++h) { + for(int w = 0; w < params.wS[1]; ++w) { + for(int c = 0; c < params.C; ++c) { + + // Local in + for(int m = 0; m < TM; m++) { + int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0]; + int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1]; + + bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; + in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0); + } + + // Load weight + for (int n = 0; n < TN; ++n) { + int o = out_o + n; + wt_local[n] = o < params.O ? 
wt[o * params.wt_strides[0] + + h * params.wt_strides[1] + + w * params.wt_strides[2] + c] : T(0); + } + + // Accumulate + for(int m = 0; m < TM; ++m) { + for(int n = 0; n < TN; ++n) { + out_local[m * TN + n] += in_local[m] * wt_local[n]; + } + } + + } + } + } + + for(int m = 0; m < TM; ++m) { + for(int n = 0; n < TN; ++n) { + if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O) + out[out_h[m] * params.out_strides[1] + + out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n]; + } + } + +} + +// Instantiations + +#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ + template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ + [[kernel]] void naive_conv_2d( \ + const device itype* in [[buffer(0)]], \ + const device itype* wt [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant MLXConvParams<2>& params [[buffer(3)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_naive_conv_2d_blocks(name, itype) \ + instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ + instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) + +instantiate_naive_conv_2d_blocks(float32, float); +instantiate_naive_conv_2d_blocks(float16, half); +instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); + +/////////////////////////////////////////////////////////////////////////////// +/// Implicit gemm kernels +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d( + const device T* in [[buffer(0)]], + const device T* wt [[buffer(1)]], + device T* out [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + using gemm_kernel = Conv2DImplicitGEMMKernel; + + threadgroup T tgp_memory[gemm_kernel::tgp_mem_size]; + + gemm_kernel::run( + in, wt, out, + params, tgp_memory, + tid, lid, simd_gid, simd_lid + ); + +} + +#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \ + template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \ + [[kernel]] void implicit_gemm_conv_2d( \ + const device itype* in [[buffer(0)]], \ + const device itype* wt [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant MLXConvParams<2>& params [[buffer(3)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_implicit_2d_blocks(name, itype) \ + instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \ + instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \ + instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2) + +instantiate_implicit_2d_blocks(float32, float); +instantiate_implicit_2d_blocks(float16, half); +instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); + +/////////////////////////////////////////////////////////////////////////////// +/// Winograd kernels +/////////////////////////////////////////////////////////////////////////////// + +template +struct WinogradTransforms { + 
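+  // Primary template intentionally left empty: only the F(6x6, 3x3) variant
+  // with 8x8 simdgroup matrices is specialized below.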
+}; + +template <> +struct WinogradTransforms<6, 3, 8> { + MLX_MTL_CONST int OUT_TILE_SIZE = 6; + MLX_MTL_CONST int FILTER_SIZE = 3; + MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1; + MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8; + MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { + { 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, + { 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f}, + {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f}, + { 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f}, + { 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f}, + { 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f}, + {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f}, + { 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, + }; + + MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { + { 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, + { 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}, + { 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f}, + { 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f}, + { 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f}, + { 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f}, + { 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f}, + { 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, + }; + + MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { + { 1.00, 0.00, 0.00}, + { -2.0/9.00, -2.0/9.00, -2.0/9.00}, + { -2.0/9.00, 2.0/9.00, -2.0/9.00}, + { 1.0/90.0, 1.0/45.0, 2.0/45.0}, + { 1.0/90.0, -1.0/45.0, 2.0/45.0}, + { 32.0/45.0, 16.0/45.0, 8.0/45.0}, + { 32.0/45.0, -16.0/45.0, 8.0/45.0}, + { 0.00, 0.00, 1.00}, + }; +}; + +constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8]; +constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8]; +constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8]; + +template +[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform( + const device T* wt_in [[buffer(0)]], + device T* wt_out [[buffer(1)]], + const constant int& C [[buffer(2)]], + const constant int& O [[buffer(3)]], + uint tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + + using WGT = WinogradTransforms; + + // Get lane position in simdgroup + const short qid = simd_lane_id / 4; + const short sm = (qid & 4) + (simd_lane_id / 2) % 4; + const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Initialize G matrix + simdgroup_matrix G; + G.thread_elements()[0] = WGT::wt_transform[sm][sn]; + G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1]; + + // Initialize Gt matrix + simdgroup_matrix Gt; + Gt.thread_elements()[0] = WGT::wt_transform[sn][sm]; + Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm]; + + // Move to the correct output filter + size_t ko = BO * tid + simd_group_id; + wt_in += ko * R * R * C; + + // wt_out is stored transposed (A x A x C x O) + short ohw_0 = sm * 8 + sn; + short ohw_1 = sm * 8 + sn + 1; + device T* wt_out_0 = wt_out + ohw_0 * C * O + ko; + device T* wt_out_1 = wt_out + ohw_1 * C * O + ko; + + // Prepare shared memory + threadgroup T Ws[BO][R][R][BC]; + + // Loop over C + for(int bc = 0; bc < C; bc += BC) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Read into shared memory + for(int kh = 0; kh < R; ++kh) { + for(int kw = 0; kw < R; ++kw) { + for(int kc = simd_lane_id; kc < BC; kc += 32) { + Ws[simd_group_id][kh][kw][kc] = 
wt_in[kh * R * C + kw * C + kc]; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Do transform and store the result + for(int c = 0; c < BC; ++c) { + simdgroup_matrix g; + g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0); + g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0); + + simdgroup_matrix g_out = (G * g) * Gt; + wt_out_0[c * O] = g_out.thread_elements()[0]; + wt_out_1[c * O] = g_out.thread_elements()[1]; + } + + wt_in += BC; + wt_out_0 += BC * O; + wt_out_1 += BC * O; + } + +} + +#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \ + template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]]\ + [[kernel]] void winograd_conv_2d_weight_transform(\ + const device itype* wt_in [[buffer(0)]],\ + device itype* wt_out [[buffer(1)]],\ + const constant int& C [[buffer(2)]],\ + const constant int& O [[buffer(3)]],\ + uint tid [[threadgroup_position_in_grid]],\ + uint simd_group_id [[simdgroup_index_in_threadgroup]],\ + uint simd_lane_id [[thread_index_in_simdgroup]]); + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform( + const device T* inp_in [[buffer(0)]], + device T* inp_out [[buffer(1)]], + const constant MLXConvParams<2>& params [[buffer(2)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_per_grid [[threadgroups_per_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + + (void)lid; + + using WGT = WinogradTransforms; + constexpr int A = WGT::IN_TILE_SIZE; + constexpr int N_SIMD_GROUPS = WM * WN; + + // Get lane position in simdgroup + const short qid = simd_lane_id / 4; + const short sm = (qid & 4) + (simd_lane_id / 2) % 4; + const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Initialize B matrix + simdgroup_matrix B; + B.thread_elements()[0] = WGT::in_transform[sm][sn]; + B.thread_elements()[1] = WGT::in_transform[sm][sn + 1]; + + // Initialize Bt matrix + simdgroup_matrix Bt; + Bt.thread_elements()[0] = WGT::in_transform[sn][sm]; + Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm]; + + // Resolve input tile + constexpr int TH = (A / WM); + constexpr int TW = (A / WN); + int kh = TH * (simd_group_id / WN); + int kw = TW * (simd_group_id % WN); + int bh = M * tid.y + kh; + int bw = M * tid.x + kw; + + // Move to the correct input tile + inp_in += tid.z * params.in_strides[0] + + bh * params.in_strides[1] + + bw * params.in_strides[2]; + + // Pre compute strides + int jump_in[TH][TW]; + + for(int h = 0; h < TH; h++) { + for(int w = 0; w < TW; w++) { + jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2]; + } + } + + // inp_out is stored interleaved (A x A x tiles x C) + size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; + size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; + size_t ohw_0 = sm * 8 + sn; + size_t ohw_1 = sm * 8 + sn + 1; + device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C; + device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C; + + // Prepare shared memory + threadgroup T Is[A][A][BC]; + + // Loop over C + for(int bc = 0; bc < params.C; bc += BC) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Read into shared memory + for(int h = 0; h < TH; h++) { + for(int w = 0; w < TW; w++) { + const device T* in_ptr = 
inp_in + jump_in[h][w]; + for(int c = simd_lane_id; c < BC; c += 32) { + Is[kh + h][kw + w][c] = in_ptr[c]; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Do transform and store the result + for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) { + simdgroup_matrix I; + I.thread_elements()[0] = Is[sm][sn][c]; + I.thread_elements()[1] = Is[sm][sn + 1][c]; + + simdgroup_matrix I_out = (Bt * I) * B; + inp_out_0[c] = I_out.thread_elements()[0]; + inp_out_1[c] = I_out.thread_elements()[1]; + } + + inp_in += BC; + inp_out_0 += BC; + inp_out_1 += BC; + } + +} + +#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \ + template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]]\ + [[kernel]] void winograd_conv_2d_input_transform(\ + const device itype* inp_in [[buffer(0)]],\ + device itype* inp_out [[buffer(1)]],\ + const constant MLXConvParams<2>& params [[buffer(2)]],\ + uint3 tid [[threadgroup_position_in_grid]],\ + uint3 lid [[thread_position_in_threadgroup]],\ + uint3 tgp_per_grid [[threadgroups_per_grid]],\ + uint simd_group_id [[simdgroup_index_in_threadgroup]],\ + uint simd_lane_id [[thread_index_in_simdgroup]]); + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform( + const device T* out_in [[buffer(0)]], + device T* out_out [[buffer(1)]], + const constant MLXConvParams<2>& params [[buffer(2)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_per_grid [[threadgroups_per_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + + (void)lid; + + using WGT = WinogradTransforms; + constexpr int N_SIMD_GROUPS = WM * WN; + + // Get lane position in simdgroup + const short qid = simd_lane_id / 4; + const short sm = (qid & 4) + (simd_lane_id / 2) % 4; + const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Initialize A matrix + simdgroup_matrix B; + B.thread_elements()[0] = WGT::out_transform[sm][sn]; + B.thread_elements()[1] = WGT::out_transform[sm][sn + 1]; + + // Initialize At matrix + simdgroup_matrix Bt; + Bt.thread_elements()[0] = WGT::out_transform[sn][sm]; + Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm]; + + // Out_in comes in shape (A x A x tiles x O) + // We do transform and then write out to out_out in shape (N, H, W, O) + + // Resolve output tile + constexpr int TH = (M / WM); + constexpr int TW = (M / WN); + int kh = TH * (simd_group_id / WN); + int kw = TW * (simd_group_id % WN); + int bh = M * tid.y + kh; + int bw = M * tid.x + kw; + + // Move to the correct input tile + out_out += tid.z * params.out_strides[0] + + bh * params.out_strides[1] + + bw * params.out_strides[2]; + + // Pre compute strides + int jump_in[TH][TW]; + + for(int h = 0; h < TH; h++) { + for(int w = 0; w < TW; w++) { + bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]); + jump_in[h][w] = valid ? 
h * params.out_strides[1] + w * params.out_strides[2] : -1; + } + } + + // out_in is stored interleaved (A x A x tiles x O) + size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; + size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; + size_t ohw_0 = sm * 8 + sn; + size_t ohw_1 = sm * 8 + sn + 1; + const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O; + const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O; + + // Prepare shared memory + threadgroup T Os[M][M][BO]; + + // Loop over O + for(int bo = 0; bo < params.O; bo += BO) { + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Do transform and store the result + for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) { + simdgroup_matrix O_mat; + O_mat.thread_elements()[0] = out_in_0[c]; + O_mat.thread_elements()[1] = out_in_1[c]; + + simdgroup_matrix O_out = (Bt * (O_mat * B)); + if((sm < M) && (sn < M)) { + Os[sm][sn][c] = O_out.thread_elements()[0]; + } + if((sm < M) && ((sn + 1) < M)) { + Os[sm][sn + 1][c] = O_out.thread_elements()[1]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Read out from shared memory + for(int h = 0; h < TH; h++) { + for(int w = 0; w < TW; w++) { + if(jump_in[h][w] >= 0) { + device T* out_ptr = out_out + jump_in[h][w]; + for(int c = simd_lane_id; c < BO; c += 32) { + out_ptr[c] = Os[kh + h][kw + w][c]; + } + } + } + } + + out_out += BO; + out_in_0 += BO; + out_in_1 += BO; + } + +} + +#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \ + template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\ + [[kernel]] void winograd_conv_2d_output_transform(\ + const device itype* out_in [[buffer(0)]],\ + device itype* out_out [[buffer(1)]],\ + const constant MLXConvParams<2>& params [[buffer(2)]],\ + uint3 tid [[threadgroup_position_in_grid]],\ + uint3 lid [[thread_position_in_threadgroup]],\ + uint3 tgp_per_grid [[threadgroups_per_grid]],\ + uint simd_group_id [[simdgroup_index_in_threadgroup]],\ + uint simd_lane_id [[thread_index_in_simdgroup]]); + +#define instantiate_winograd_conv_2d(name, itype) \ + instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \ + instantiate_winograd_conv_2d_input_transform(name, itype, 32) \ + instantiate_winograd_conv_2d_output_transform(name, itype, 32) + +instantiate_winograd_conv_2d(float32, float); +instantiate_winograd_conv_2d(float16, half); \ No newline at end of file diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal new file mode 100644 index 000000000..07ae22baa --- /dev/null +++ b/mlx/backend/metal/kernels/copy.metal @@ -0,0 +1,269 @@ +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +template +[[kernel]] void copy_s( + device const T* src, + device U* dst, + uint index [[thread_position_in_grid]]) { + dst[index] = static_cast(src[0]); +} + +template +[[kernel]] void copy_v( + device const T* src, + device U* dst, + uint index [[thread_position_in_grid]]) { + dst[index] = static_cast(src[index]); +} + +template +[[kernel]] void copy_g_nd1( + device const T* src, + device U* dst, + constant const size_t& src_stride, + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + dst[index] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd2( + device const T* src, + device U* dst, + constant const size_t src_strides[2], + uint2 index 
[[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + size_t dst_idx = index.x + (size_t)grid_dim.x * index.y; + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd3( + device const T* src, + device U* dst, + constant const size_t src_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd( + device const T* src, + device U* dst, + constant const int src_shape[DIM], + constant const size_t src_strides[DIM], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); + size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g( + device const T* src, + device U* dst, + constant const int* src_shape, + constant const size_t* src_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); + size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd1( + device const T* src, + device U* dst, + constant const size_t& src_stride, + constant const size_t& dst_stride, + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd2( + device const T* src, + device U* dst, + constant const size_t src_strides[2], + constant const size_t dst_strides[2], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd3( + device const T* src, + device U* dst, + constant const size_t src_strides[3], + constant const size_t dst_strides[3], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd( + device const T* src, + device U* dst, + constant const int src_shape[DIM], + constant const size_t src_strides[DIM], + constant const size_t dst_strides[DIM], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); + auto dst_idx = elem_to_loc_nd(index, src_shape, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg( + device const T* src, + device U* dst, + constant const int* src_shape, + constant const size_t* src_strides, + constant const size_t* dst_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); + auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim); + dst[dst_idx] = static_cast(src[src_idx]); +} + +#define instantiate_copy(name, itype, otype, ctype) \ + template [[host_name(name)]] \ 
+ [[kernel]] void copy_##ctype( \ + device const itype* src, \ + device otype* dst, \ + uint index [[thread_position_in_grid]]); + +#define instantiate_copy_g_dim(name, itype, otype, dims) \ + template [[host_name(name "_" #dims)]] \ + [[kernel]] void copy_g_nd( \ + device const itype* src, \ + device otype* dst, \ + constant const int src_shape[dims], \ + constant const size_t src_strides[dims], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("g" name "_" #dims)]] \ + [[kernel]] void copy_gg_nd( \ + device const itype* src, \ + device otype* dst, \ + constant const int src_shape[dims], \ + constant const size_t src_strides[dims], \ + constant const size_t dst_strides[dims], \ + uint3 index [[thread_position_in_grid]]); + + +#define instantiate_copy_g_nd(name, itype, otype) \ + template [[host_name(name "_1")]] \ + [[kernel]] void copy_g_nd1( \ + device const itype* src, \ + device otype* dst, \ + constant const size_t& src_stride, \ + uint index [[thread_position_in_grid]]); \ + template [[host_name(name "_2")]] \ + [[kernel]] void copy_g_nd2( \ + device const itype* src, \ + device otype* dst, \ + constant const size_t src_strides[2], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name(name "_3")]] \ + [[kernel]] void copy_g_nd3( \ + device const itype* src, \ + device otype* dst, \ + constant const size_t src_strides[3], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("g" name "_1")]] \ + [[kernel]] void copy_gg_nd1( \ + device const itype* src, \ + device otype* dst, \ + constant const size_t& src_stride, \ + constant const size_t& dst_stride, \ + uint index [[thread_position_in_grid]]); \ + template [[host_name("g" name "_2")]] \ + [[kernel]] void copy_gg_nd2( \ + device const itype* src, \ + device otype* dst, \ + constant const size_t src_strides[2], \ + constant const size_t dst_strides[2], \ + uint2 index [[thread_position_in_grid]]); \ + template [[host_name("g" name "_3")]] \ + [[kernel]] void copy_gg_nd3( \ + device const itype* src, \ + device otype* dst, \ + constant const size_t src_strides[3], \ + constant const size_t dst_strides[3], \ + uint3 index [[thread_position_in_grid]]); \ + instantiate_copy_g_dim(name, itype, otype, 4) \ + instantiate_copy_g_dim(name, itype, otype, 5) + + +#define instantiate_copy_g(name, itype, otype) \ + template [[host_name(name)]] \ + [[kernel]] void copy_g( \ + device const itype* src, \ + device otype* dst, \ + constant const int* src_shape, \ + constant const size_t* src_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("g" name)]] \ + [[kernel]] void copy_gg( \ + device const itype* src, \ + device otype* dst, \ + constant const int* src_shape, \ + constant const size_t* src_strides, \ + constant const size_t* dst_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]]); + +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_copy("scopy" #tname, itype, otype, s) \ + instantiate_copy("vcopy" #tname, itype, otype, v) \ + instantiate_copy_g("gcopy" #tname, itype, otype) \ + instantiate_copy_g_nd("gcopy" #tname, itype, otype) + +#define instantiate_copy_itype(itname, itype) \ + instantiate_copy_all(itname ##bool_, itype, bool) \ + instantiate_copy_all(itname ##uint8, itype, uint8_t) \ + instantiate_copy_all(itname 
##uint16, itype, uint16_t) \ + instantiate_copy_all(itname ##uint32, itype, uint32_t) \ + instantiate_copy_all(itname ##uint64, itype, uint64_t) \ + instantiate_copy_all(itname ##int8, itype, int8_t) \ + instantiate_copy_all(itname ##int16, itype, int16_t) \ + instantiate_copy_all(itname ##int32, itype, int32_t) \ + instantiate_copy_all(itname ##int64, itype, int64_t) \ + instantiate_copy_all(itname ##float16, itype, half) \ + instantiate_copy_all(itname ##float32, itype, float) \ + instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ + instantiate_copy_all(itname ##complex64, itype, complex64_t) + +instantiate_copy_itype(bool_, bool) +instantiate_copy_itype(uint8, uint8_t) +instantiate_copy_itype(uint16, uint16_t) +instantiate_copy_itype(uint32, uint32_t) +instantiate_copy_itype(uint64, uint64_t) +instantiate_copy_itype(int8, int8_t) +instantiate_copy_itype(int16, int16_t) +instantiate_copy_itype(int32, int32_t) +instantiate_copy_itype(int64, int64_t) +instantiate_copy_itype(float16, half) +instantiate_copy_itype(float32, float) +instantiate_copy_itype(bfloat16, bfloat16_t) +instantiate_copy_itype(complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/erf.h b/mlx/backend/metal/kernels/erf.h new file mode 100644 index 000000000..de9cc58c9 --- /dev/null +++ b/mlx/backend/metal/kernels/erf.h @@ -0,0 +1,68 @@ +#pragma once + +#include + +/* + * Approximation to the error function. + * Based on code from: + * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 + */ +float erf(float a) { + float r, s, t, u; + t = metal::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = metal::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = metal::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = metal::fma(r, s, u); + r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = metal::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - metal::exp(r); + r = metal::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = metal::fma(r, a, a); + } + return r; +} + +float erfinv(float a) { + auto t = metal::fma(a, 0.0f - a, 1.0f); + t = metal::log(t); + float p; + if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = metal::fma(p, t, 1.12963626e-7f); // 
0x1.e52cd2p-24 + p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} \ No newline at end of file diff --git a/mlx/backend/metal/kernels/gemm.metal b/mlx/backend/metal/kernels/gemm.metal new file mode 100644 index 000000000..9f51e6ff8 --- /dev/null +++ b/mlx/backend/metal/kernels/gemm.metal @@ -0,0 +1,91 @@ +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/gemm/gemm.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm( + const device T *A [[buffer(0)]], + const device T *B [[buffer(1)]], + device T *C [[buffer(2)]], + const constant int &M [[buffer(3)]], + const constant int &N [[buffer(4)]], + const constant int &K [[buffer(5)]], + const constant int &batch_stride_a [[buffer(6)]], + const constant int &batch_stride_b [[buffer(7)]], + const constant int &batch_stride_c [[buffer(8)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + using gemm_kernel = GEMMKernel; + + threadgroup T tgp_memory[gemm_kernel::tgp_mem_size]; + + gemm_kernel::run( + A, B, C, + M, N, K, + batch_stride_a, batch_stride_b, batch_stride_c, + tgp_memory, + simd_lane_id, simd_group_id, tid, lid + ); +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \ + [[kernel]] void gemm( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + device itype *C [[buffer(2)]], \ + const constant int &M [[buffer(3)]], \ + const constant int &N [[buffer(4)]], \ + const constant int &K [[buffer(5)]], \ + const constant int &batch_stride_a [[buffer(6)]], \ + const constant int &batch_stride_b [[buffer(7)]], \ + const constant int &batch_stride_c [[buffer(8)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, 
false, naligned, false) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) + +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(float32, float, float32, float); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); + +// TODO: Accumulation in different type \ No newline at end of file diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal new file mode 100644 index 000000000..d8e4c4208 --- /dev/null +++ b/mlx/backend/metal/kernels/random.metal @@ -0,0 +1,99 @@ +#include "mlx/backend/metal/kernels/utils.h" + +static constexpr constant uint32_t rotations[2][4] = { + {13, 15, 26, 6}, + {17, 29, 16, 24} +}; + +union rbits { + uint2 val; + uchar4 bytes[2]; +}; + +rbits threefry2x32_hash(const thread uint2& key, uint2 count) { + + uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; + + rbits v; + v.val.x = count.x + ks[0]; + v.val.y = count.y + ks[1]; + + for (int i = 0; i < 5; ++i) { + for (auto r : rotations[i % 2]) { + v.val.x += v.val.y; + v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); + v.val.y ^= v.val.x; + } + v.val.x += ks[(i + 1) % 3]; + v.val.y += ks[(i + 2) % 3] + i + 1; + } + + return v; +} + +[[kernel]] void rbitsc( + device const uint32_t* keys, + device char* out, + device const bool& odd, + device const uint& bytes_per_key, + uint2 grid_dim [[threads_per_grid]], + uint2 index [[thread_position_in_grid]]) { + auto kidx = 2 * index.x; + auto key = uint2(keys[kidx], keys[kidx + 1]); + auto half_size = grid_dim.y - odd; + out += index.x * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto count = uint2(index.y, drop_last ? 
0 : index.y + grid_dim.y); + auto bits = threefry2x32_hash(key, count); + for (int i = 0; i < 4; ++i) { + out[4 * count.x + i] = bits.bytes[0][i]; + } + if (!drop_last) { + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[4 * count.y + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[4 * count.y + i] = bits.bytes[1][i]; + } + } + } +} + +[[kernel]] void rbits( + device const uint32_t* keys, + device char* out, + device const bool& odd, + device const uint& bytes_per_key, + device const int& ndim, + device const int* key_shape, + device const size_t* key_strides, + uint2 grid_dim [[threads_per_grid]], + uint2 index [[thread_position_in_grid]]) { + auto kidx = 2 * index.x; + auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim); + auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim); + auto key = uint2(keys[k1_elem], keys[k2_elem]); + auto half_size = grid_dim.y - odd; + out += index.x * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y); + auto bits = threefry2x32_hash(key, count); + for (int i = 0; i < 4; ++i) { + out[4 * count.x + i] = bits.bytes[0][i]; + } + if (!drop_last) { + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[4 * count.y + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[4 * count.y + i] = bits.bytes[1][i]; + } + } + } +} diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal new file mode 100644 index 000000000..f69514906 --- /dev/null +++ b/mlx/backend/metal/kernels/reduce.metal @@ -0,0 +1,536 @@ +#include +#include + +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/reduce.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +static constant uint8_t simd_size = 32; + +template +[[kernel]] void init_reduce( + device T *out [[buffer(0)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = Op::init; +} + +#define instantiate_init_reduce(name, otype, op) \ + template [[host_name("i" #name)]] \ + [[kernel]] void init_reduce( \ + device otype *out [[buffer(1)]], \ + uint tid [[thread_position_in_grid]]); + + +/////////////////////////////////////////////////////////////////////////////// +// All reduce +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void all_reduce( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const device size_t& in_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint grid_size [[threads_per_grid]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + // NB: this kernel assumes threads_per_threadgroup is at most + // 1024. This way with a simd_size of 32, we are guaranteed to + // complete the reduction in two steps of simd-level reductions. 
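+  //
+  // Worked check of that bound: with 1024 threads per threadgroup and
+  // simd_size = 32, the first op.simd_reduce collapses each simdgroup to a
+  // single partial, leaving at most 1024 / 32 = 32 partials in local_vals.
+  // Those 32 partials fit in one simdgroup, so a second op.simd_reduce
+  // finishes the threadgroup-level reduction, and the final atomic_update
+  // combines the per-threadgroup results.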
+ + Op op; + threadgroup U local_vals[simd_size]; + + U total_val = Op::init; + + in += gid * N_READS; + + int r = 0; + for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { + U vals[N_READS] = {op.init}; + + for(int i = 0; i < N_READS; i++) { + vals[i] = static_cast(in[i]); + } + for(int i = 0; i < N_READS; i++) { + total_val = op(vals[i], total_val); + } + + in += grid_size * N_READS; + } + + // Sepate case for the last set as we close the reduction size + size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; + if (curr_idx < in_size) { + int max_reads = in_size - curr_idx; + T vals[N_READS]; + + for(int i = 0, idx = 0; i < N_READS; i++, idx++) { + idx = idx < max_reads ? idx : max_reads - 1; + vals[i] = in[idx]; + } + for(int i = 0; i < N_READS; i++) { + U val = i < max_reads ? vals[i] : Op::init; + total_val = op(static_cast(val), total_val); + } + } + + // Reduction within simd group + total_val = op.simd_reduce(total_val); + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + + // Reduction within thread group + threadgroup_barrier(mem_flags::mem_threadgroup); + total_val = lid < simd_per_group ? local_vals[lid] : op.init; + total_val = op.simd_reduce(total_val); + + // Reduction across threadgroups + if (lid == 0) { + op.atomic_update(out, total_val); + } +} + +#define instantiate_all_reduce(name, itype, otype, op) \ + template [[host_name("all_reduce_" #name)]] \ + [[kernel]] void all_reduce( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + + +/////////////////////////////////////////////////////////////////////////////// +// General reduce +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void general_reduce( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const device int *in_shape [[buffer(2)]], + const device size_t *in_strides [[buffer(3)]], + const device size_t *out_strides [[buffer(4)]], + const device size_t& ndim [[buffer(5)]], + uint gid [[thread_position_in_grid]]) { + Op op; + auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim); + auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim); + op.atomic_update(out, static_cast(in[in_idx]), out_idx); +} + +template +[[kernel]] void general_reduce( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const device int *in_shape [[buffer(2)]], + const device size_t *in_strides [[buffer(3)]], + const device size_t *out_strides [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + Op op; + auto in_idx = elem_to_loc_nd(gid, in_shape, in_strides); + auto out_idx = elem_to_loc_nd(gid, in_shape, out_strides); + op.atomic_update(out, static_cast(in[in_idx]), out_idx); +} + +#define instantiate_general_reduce_helper(name, itype, otype, op) \ + template [[host_name("general_reduce_" #name)]] \ + [[kernel]] void general_reduce( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const device int *in_shape [[buffer(2)]], \ + const device size_t *in_strides [[buffer(3)]], \ + const device size_t *out_strides [[buffer(4)]], \ + const device size_t& ndim 
[[buffer(5)]], \ + uint gid [[thread_position_in_grid]]); + +#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \ + template [[host_name("general_reduce_" #name "_dim_" #n)]] \ + [[kernel]] void general_reduce( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const device int *in_shape [[buffer(2)]], \ + const device size_t *in_strides [[buffer(3)]], \ + const device size_t *out_strides [[buffer(4)]], \ + uint gid [[thread_position_in_grid]]); + +#define instantiate_general_reduce(name, itype, otype, op) \ + instantiate_general_reduce_helper(name, itype, otype, op) \ + instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \ + instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \ + instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \ + instantiate_general_reduce_helper_nd(name, itype, otype, op, 4) + + +/////////////////////////////////////////////////////////////////////////////// +// Row atomics +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void row_reduce( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const device size_t& reduction_size [[buffer(2)]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + + Op op; + + // Each threadgroup handles 1 reduction + in += tid * reduction_size + lid * N_READS; + + // The reduction is accumulated here + U total_val = Op::init; + threadgroup U local_vals[simd_size]; + + // Loop over the reduction size within thread group + int r = 0; + for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) { + T vals[N_READS]; + for(int i = 0; i < N_READS; i++) { + vals[i] = in[i]; + } + for(int i = 0; i < N_READS; i++) { + total_val = op(static_cast(vals[i]), total_val); + } + + in += lsize * N_READS; + } + + // Sepate case for the last set as we close the reduction size + size_t reduction_index = (lid + (size_t)lsize * r) * N_READS; + if(reduction_index < reduction_size) { + int max_reads = reduction_size - reduction_index; + + T vals[N_READS]; + for(int i = 0; i < N_READS; i++) { + int idx = min(i, max_reads - 1); + vals[i] = static_cast(in[idx]); + } + for(int i = 0; i < N_READS; i++) { + T val = i < max_reads ? vals[i] : Op::init; + total_val = op(static_cast(val), total_val); + } + } + + total_val = op.simd_reduce(total_val); + + // Prepare next level + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduction within thread group + // Only needed if multiple simd groups + if(reduction_size > simd_size) { + total_val = lid < simd_per_group ? 
local_vals[lid] : op.init; + total_val = op.simd_reduce(total_val); + } + // Update output + if (lid == 0) { + out[tid] = total_val; + } +} + +#define instantiate_row_reduce(name, itype, otype, op) \ + template [[host_name("row_reduce_" #name)]] \ + [[kernel]] void row_reduce( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const device size_t& reduction_size [[buffer(2)]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + + +/////////////////////////////////////////////////////////////////////////////// +// Column reduce +/////////////////////////////////////////////////////////////////////////////// + +template +inline void _contiguous_strided_reduce( + const device T *in, + device mlx_atomic *out, + threadgroup U *local_data, + uint in_idx, + uint out_idx, + uint reduction_size, + uint reduction_stride, + uint2 tid, + uint2 lid, + uint2 lsize) { + + Op op; + T local_vals[N_READS]; + + uint base_offset = (tid.y * lsize.y + lid.y) * N_READS; + + for(uint r = 0; r < N_READS; r++) { + uint offset = base_offset + r; + offset = offset < reduction_size ? offset : reduction_size - 1; + local_vals[r] = in[in_idx + offset * reduction_stride]; + } + + U total_val = Op::init; + for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { + total_val = op(static_cast(total_val), local_vals[r]); + } + local_data[lsize.y * lid.x + lid.y] = total_val; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if(lid.y == 0) { + U val = op.init; + + for(uint i = 0; i < lsize.y; i++) { + val = op(val, local_data[lsize.y * lid.x + i]); + } + + op.atomic_update(out, val, out_idx); + } +} + +template +[[kernel]] void col_reduce( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + threadgroup U *local_data [[threadgroup(0)]], + uint2 tid [[threadgroup_position_in_grid]], + uint2 lid [[thread_position_in_threadgroup]], + uint2 lsize [[threads_per_threadgroup]]) { + auto out_idx = tid.x * lsize.x + lid.x; + + if(out_idx < out_size) { + _contiguous_strided_reduce( + in, + out, + local_data, + out_idx, + out_idx, + reduction_size, + reduction_stride, + tid, + lid, + lsize); + } +} + +#define instantiate_col_reduce(name, itype, otype, op) \ + template [[host_name("col_reduce_" #name)]] \ + [[kernel]] void col_reduce( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + threadgroup otype *local_data [[threadgroup(0)]], \ + uint2 tid [[threadgroup_position_in_grid]], \ + uint2 lid [[thread_position_in_threadgroup]], \ + uint2 lsize [[threads_per_threadgroup]]); + +template +[[kernel]] void contiguous_strided_reduce( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const device int* in_shape [[buffer(5)]], + const device size_t* in_strides 
[[buffer(6)]], + threadgroup U *local_data [[threadgroup(0)]], + uint2 tid [[threadgroup_position_in_grid]], + uint2 lid [[thread_position_in_threadgroup]], + uint2 lsize [[threads_per_threadgroup]]) { + + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc_nd(out_idx, in_shape, in_strides); + + if(out_idx < out_size) { + _contiguous_strided_reduce( + in, + out, + local_data, + in_idx, + out_idx, + reduction_size, + reduction_stride, + tid, + lid, + lsize); + } +} + +template +[[kernel]] void contiguous_strided_reduce( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const device int* in_shape [[buffer(5)]], + const device size_t* in_strides [[buffer(6)]], + const device size_t& in_dim [[buffer(7)]], + threadgroup U *local_data [[threadgroup(0)]], + uint2 tid [[threadgroup_position_in_grid]], + uint2 lid [[thread_position_in_threadgroup]], + uint2 lsize [[threads_per_threadgroup]]) { + + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim); + + if(out_idx < out_size) { + _contiguous_strided_reduce( + in, + out, + local_data, + in_idx, + out_idx, + reduction_size, + reduction_stride, + tid, + lid, + lsize); + } +} + +#define instantiate_contiguous_strided_helper(name, itype, otype, op) \ + template [[host_name("contiguous_strided_reduce_" #name)]] \ + [[kernel]] void contiguous_strided_reduce( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const device int* in_shape [[buffer(5)]], \ + const device size_t* in_strides [[buffer(6)]], \ + const device size_t& in_dim [[buffer(7)]], \ + threadgroup otype *local_data [[threadgroup(0)]], \ + uint2 tid [[threadgroup_position_in_grid]], \ + uint2 lid [[thread_position_in_threadgroup]], \ + uint2 lsize [[threads_per_threadgroup]]); + +#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \ + template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \ + [[kernel]] void contiguous_strided_reduce( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const device int* in_shape [[buffer(5)]], \ + const device size_t* in_strides [[buffer(6)]], \ + threadgroup otype *local_data [[threadgroup(0)]], \ + uint2 tid [[threadgroup_position_in_grid]], \ + uint2 lid [[thread_position_in_threadgroup]], \ + uint2 lsize [[threads_per_threadgroup]]); + +#define instantiate_contiguous_strided(name, itype, otype, op) \ + instantiate_contiguous_strided_helper(name, itype, otype, op) \ + instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \ + instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \ + instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \ + instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4) + + +/////////////////////////////////////////////////////////////////////////////// +// Instantiations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_reduce(name, 
itype, otype, op) \ + instantiate_all_reduce(name, itype, otype, op) \ + instantiate_row_reduce(name, itype, otype, op) \ + instantiate_col_reduce(name, itype, otype, op) \ + instantiate_contiguous_strided(name, itype, otype, op) \ + instantiate_general_reduce(name, itype, otype, op) + +#define instantiate_same_reduce(name, tname, type, op) \ + instantiate_init_reduce(name ##tname, type, op) \ + instantiate_reduce(name ##tname, type, type, op) + +#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \ + instantiate_reduce(name ##tname, itype, otype, op) + +#define instantiate_reduce_from_types(name, otype, op) \ + instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \ + instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \ + instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \ + instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \ + instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \ + instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \ + instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \ + instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \ + instantiate_reduce_from_types_helper(name, float16, half, otype, op) \ + instantiate_reduce_from_types_helper(name, float32, float, otype, op) \ + instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op) + +// special case bool with larger output type +instantiate_reduce(sumbool_, bool, uint32_t, Sum) +instantiate_same_reduce(sum, uint8, uint8_t, Sum) +instantiate_same_reduce(sum, uint16, uint16_t, Sum) +instantiate_same_reduce(sum, uint32, uint32_t, Sum) +instantiate_same_reduce(sum, int8, int8_t, Sum) +instantiate_same_reduce(sum, int16, int16_t, Sum) +instantiate_same_reduce(sum, int32, int32_t, Sum) +instantiate_same_reduce(sum, float16, half, Sum) +instantiate_same_reduce(sum, float32, float, Sum) + +instantiate_same_reduce(prod, uint8, uint8_t, Prod) +instantiate_same_reduce(prod, uint16, uint16_t, Prod) +instantiate_same_reduce(prod, uint32, uint32_t, Prod) +instantiate_same_reduce(prod, int8, int8_t, Prod) +instantiate_same_reduce(prod, int16, int16_t, Prod) +instantiate_same_reduce(prod, int32, int32_t, Prod) +instantiate_same_reduce(prod, float16, half, Prod) +instantiate_same_reduce(prod, float32, float, Prod) + +instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum) +instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod) + +instantiate_init_reduce(andbool_, bool, And) +instantiate_reduce_from_types(and, bool, And) + +instantiate_init_reduce(orbool_, bool, Or) +instantiate_reduce_from_types(or, bool, Or) + +// Compiler segfaulted with the names "min" or "max" ... 
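+// A concrete example of the host names these macros generate:
+// instantiate_same_reduce(sum, float32, float, Sum) above registers kernels
+// such as "all_reduce_sumfloat32" and "row_reduce_sumfloat32", which are the
+// strings the host side rebuilds at dispatch time (see reduce.cpp, e.g.
+// d.get_kernel("all_reduce_" + op_name + type_to_name(in))). With the
+// trailing underscore used below, min/max come out as
+// "all_reduce_min_float32", "row_reduce_max_float16", and so on.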
+instantiate_same_reduce(min_, uint8, uint8_t, Min) +instantiate_same_reduce(min_, uint16, uint16_t, Min) +instantiate_same_reduce(min_, uint32, uint32_t, Min) +instantiate_same_reduce(min_, int8, int8_t, Min) +instantiate_same_reduce(min_, int16, int16_t, Min) +instantiate_same_reduce(min_, int32, int32_t, Min) +instantiate_same_reduce(min_, float16, half, Min) +instantiate_same_reduce(min_, float32, float, Min) + +instantiate_same_reduce(max_, uint8, uint8_t, Max) +instantiate_same_reduce(max_, uint16, uint16_t, Max) +instantiate_same_reduce(max_, uint32, uint32_t, Max) +instantiate_same_reduce(max_, int8, int8_t, Max) +instantiate_same_reduce(max_, int16, int16_t, Max) +instantiate_same_reduce(max_, int32, int32_t, Max) +instantiate_same_reduce(max_, float16, half, Max) +instantiate_same_reduce(max_, float32, float, Max) + +instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) +instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal new file mode 100644 index 000000000..213e440d4 --- /dev/null +++ b/mlx/backend/metal/kernels/unary.metal @@ -0,0 +1,284 @@ +#include +#include + +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/bf16.h" + +struct Abs { + template T operator()(T x) { return metal::abs(x); }; + template <> uint8_t operator()(uint8_t x) { return x; }; + template <> uint16_t operator()(uint16_t x) { return x; }; + template <> uint32_t operator()(uint32_t x) { return x; }; + template <> uint64_t operator()(uint64_t x) { return x; }; + template <> bool operator()(bool x) { return x; }; + template <> complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template T operator()(T x) { return metal::precise::acos(x); }; +}; + +struct ArcCosh { + template T operator()(T x) { return metal::precise::acosh(x); }; +}; + +struct ArcSin { + template T operator()(T x) { return metal::precise::asin(x); }; +}; + +struct ArcSinh { + template T operator()(T x) { return metal::precise::asinh(x); }; +}; + +struct ArcTan { + template T operator()(T x) { return metal::precise::atan(x); }; +}; + +struct ArcTanh { + template T operator()(T x) { return metal::precise::atanh(x); }; +}; + +struct Cos { + template T operator()(T x) { return metal::precise::cos(x); }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag) + }; + }; +}; + +struct Cosh { + template T operator()(T x) { return metal::precise::cosh(x); }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag) + }; + }; +}; + +struct Erf { + template T operator()(T x) { return static_cast(erf(static_cast(x))); }; +}; + +struct ErfInv { + template T operator()(T x) { return static_cast(erfinv(static_cast(x))); }; +}; + +struct Exp { + template T operator()(T x) { return metal::precise::exp(x); }; + template <> complex64_t operator()(complex64_t x) { + auto m = metal::precise::exp(x.real); + return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + } +}; + +struct Log { + template T operator()(T x) { return metal::precise::log(x); }; +}; + +struct Log2 { + template T 
operator()(T x) { return metal::precise::log2(x); }; +}; + +struct Log10 { + template T operator()(T x) { return metal::precise::log10(x); }; +}; + +struct Log1p { + template T operator()(T x) { return log1p(x); }; +}; + +struct LogicalNot { + template T operator()(T x) { return !x; }; +}; + +struct Negative { + template T operator()(T x) { return -x; }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template T operator()(T x) { return (x > T(0)) - (x < T(0)); }; + template <> uint32_t operator()(uint32_t x) { return x != 0; }; +}; + +struct Sin { + template T operator()(T x) { return metal::precise::sin(x); }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag) + }; + }; +}; + +struct Sinh { + template T operator()(T x) { return metal::precise::sinh(x); }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag) + }; + }; +}; + +struct Square { + template T operator()(T x) { return x * x; }; +}; + +struct Sqrt { + template T operator()(T x) { return metal::precise::sqrt(x); }; +}; + +struct Rsqrt { + template T operator()(T x) { return metal::precise::rsqrt(x); }; +}; + +struct Tan { + template T operator()(T x) { return metal::precise::tan(x); }; + + template <> + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return { + (tan_a - tanh_b * t1) / denom, + (tanh_b + tan_a * t1) / denom + }; + }; +}; + +struct Tanh { + template T operator()(T x) { return metal::precise::tanh(x); }; + + template <> + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. 
+ t1 * t1; + return { + (tanh_a + tan_b * t1) / denom, + (tan_b - tanh_a * t1) / denom + }; + }; +}; + +template +[[kernel]] void unary_op_v( + device const T* in, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = Op()(in[index]); +} + +template +[[kernel]] void unary_op_g( + device const T* in, + device T* out, + device const int* in_shape, + device const size_t* in_strides, + device const int& ndim, + uint index [[thread_position_in_grid]]) { + auto idx = elem_to_loc(index, in_shape, in_strides, ndim); + out[index] = Op()(in[idx]); +} + +#define instantiate_unary_v(name, type, op) \ + template [[host_name(name)]] \ + [[kernel]] void unary_op_v( \ + device const type* in, \ + device type* out, \ + uint index [[thread_position_in_grid]]); + +#define instantiate_unary_g(name, type, op) \ + template [[host_name(name)]] \ + [[kernel]] void unary_op_g( \ + device const type* in, \ + device type* out, \ + device const int* in_shape, \ + device const size_t* in_strides, \ + device const int& ndim, \ + uint index [[thread_position_in_grid]]); + +#define instantiate_unary_all(name, tname, type, op) \ + instantiate_unary_v("v" #name #tname, type, op) \ + instantiate_unary_g("g" #name #tname, type, op) + +#define instantiate_unary_float(name, op) \ + instantiate_unary_all(name, float16, half, op) \ + instantiate_unary_all(name, float32, float, op) \ + instantiate_unary_all(name, bfloat16, bfloat16_t, op) \ + +#define instantiate_unary_types(name, op) \ + instantiate_unary_all(name, bool_, bool, op) \ + instantiate_unary_all(name, uint8, uint8_t, op) \ + instantiate_unary_all(name, uint16, uint16_t, op) \ + instantiate_unary_all(name, uint32, uint32_t, op) \ + instantiate_unary_all(name, uint64, uint64_t, op) \ + instantiate_unary_all(name, int8, int8_t, op) \ + instantiate_unary_all(name, int16, int16_t, op) \ + instantiate_unary_all(name, int32, int32_t, op) \ + instantiate_unary_all(name, int64, int64_t, op) \ + instantiate_unary_float(name, op) + +instantiate_unary_types(abs, Abs) +instantiate_unary_float(arccos, ArcCos) +instantiate_unary_float(arccosh, ArcCosh) +instantiate_unary_float(arcsin, ArcSin) +instantiate_unary_float(arcsinh, ArcSinh) +instantiate_unary_float(arctan, ArcTan) +instantiate_unary_float(arctanh, ArcTanh) +instantiate_unary_float(cos, Cos) +instantiate_unary_float(cosh, Cosh) +instantiate_unary_float(exp, Exp) +instantiate_unary_float(log, Log) +instantiate_unary_float(log2, Log2) +instantiate_unary_float(log10, Log10) +instantiate_unary_float(log1p, Log1p) +instantiate_unary_types(neg, Negative) +instantiate_unary_float(sigmoid, Sigmoid) +instantiate_unary_float(erf, Erf) +instantiate_unary_float(erfinv, ErfInv) +instantiate_unary_types(sign, Sign) +instantiate_unary_float(sin, Sin) +instantiate_unary_float(sinh, Sinh) +instantiate_unary_types(square, Square) +instantiate_unary_float(sqrt, Sqrt) +instantiate_unary_float(rsqrt, Rsqrt) +instantiate_unary_float(tan, Tan) +instantiate_unary_float(tanh, Tanh) + +instantiate_unary_all(abs, complex64, complex64_t, Abs) +instantiate_unary_all(cos, complex64, complex64_t, Cos) +instantiate_unary_all(cosh, complex64, complex64_t, Cosh) +instantiate_unary_all(exp, complex64, complex64_t, Exp) +instantiate_unary_all(neg, complex64, complex64_t, Negative) +instantiate_unary_all(sin, complex64, complex64_t, Sin) +instantiate_unary_all(sinh, complex64, complex64_t, Sinh) +instantiate_unary_all(tan, complex64, complex64_t, Tan) +instantiate_unary_all(tanh, complex64, complex64_t, Tanh) + 
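+// Naming note for the instantiations above: instantiate_unary_all registers a
+// contiguous variant prefixed with "v" (unary_op_v indexes the input directly)
+// and a strided variant prefixed with "g" (unary_op_g resolves the source
+// location through elem_to_loc from shapes/strides). For example,
+// instantiate_unary_float(exp, Exp) yields the host names "vexpfloat16",
+// "gexpfloat16", "vexpfloat32", "gexpfloat32", "vexpbfloat16" and
+// "gexpbfloat16".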
+instantiate_unary_all(lnot, bool_, bool, LogicalNot) diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp new file mode 100644 index 000000000..84a368c53 --- /dev/null +++ b/mlx/backend/metal/reduce.cpp @@ -0,0 +1,369 @@ +#include +#include +#include + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +////////////////////////////////////////////////////////////////////// +// Case wise reduce dispatch +////////////////////////////////////////////////////////////////////// + +namespace { + +// All Reduce +void all_reduce_dispatch( + const array& in, + array& out, + const std::string& op_name, + MTL::ComputeCommandEncoder* compute_encoder, + metal::Device& d) { + // Get kernel and encode buffers + size_t in_size = in.size(); + auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in)); + + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&in_size, sizeof(size_t), 2); + + // Set grid dimensions + + // We make sure each thread has enough to do by making it read in + // atleast n_reads inputs + int n_reads = REDUCE_N_READS; + + // mod_in_size gives us the groups of n_reads needed to go over the entire + // input + uint mod_in_size = (in_size + n_reads - 1) / n_reads; + + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + thread_group_size = + mod_in_size > thread_group_size ? thread_group_size : mod_in_size; + + // If the number of thread groups needed exceeds 1024, we reuse threads groups + uint n_thread_groups = + (mod_in_size + thread_group_size - 1) / thread_group_size; + n_thread_groups = std::min(n_thread_groups, 1024u); + uint nthreads = n_thread_groups * thread_group_size; + + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +void row_reduce_dispatch( + const array& in, + array& out, + const std::string& op_name, + const std::vector& axes_, + MTL::ComputeCommandEncoder* compute_encoder, + metal::Device& d) { + auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in)); + + int n_reads = REDUCE_N_READS; + size_t reduction_size = in.size() / out.size(); + + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + + // Each thread group is responsible for 1 output + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + thread_group_size = + std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); + + // Align thread group size with simd_size + uint simd_size = kernel->threadExecutionWidth(); + thread_group_size = + (thread_group_size + simd_size - 1) / simd_size * simd_size; + assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); + + // Launch enough thread groups for each output + size_t n_threads = out.size() * thread_group_size; + MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +void col_reduce_dispatch( + const array& in, + array& out, + const std::string& 
op_name, + const std::vector& axes_, + MTL::ComputeCommandEncoder* compute_encoder, + metal::Device& d) { + std::ostringstream kernel_name; + + bool encode_in_shape = false; + bool encode_ndim = false; + + // If the slowest moving axis can be merged into the reductions, + // we call the column reduce kernel + // In this case, a linear index in the output corresponds to the + // linear index in the input where the reduction starts + if (axes_[axes_.size() - 1] == (axes_.size() - 1)) { + kernel_name << "col_reduce_" << op_name << type_to_name(in); + } + // Otherwise, while all the reduction axes can be merged, the mapping between + // indices in the output and input require resolving using shapes and strides + else { + kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in); + encode_in_shape = true; + + // We check for a viable template with the required number of dimensions + // we only care about encoding non-reduced shapes and strides in the input + size_t non_reducing_dims = in.ndim() - axes_.size(); + if (non_reducing_dims >= 1 && + non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) { + kernel_name << "_dim_" << non_reducing_dims; + } else { + encode_ndim = true; + } + } + + auto kernel = d.get_kernel(kernel_name.str()); + size_t in_size = in.size(); + size_t out_size = out.size(); + + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + + // Calculate the number of inputs to reduce and the stride b/w them + size_t reduction_size = 1; + size_t in_ndim = in.ndim(); + size_t reduction_stride = in_size; + + for (int i : axes_) { + reduction_size *= in.shape(i); + reduction_stride = std::min(reduction_stride, in.strides()[i]); + } + + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); + compute_encoder->setBytes(&out_size, sizeof(size_t), 4); + if (encode_in_shape) { + // Obtain the non-reducing shape and strides of the input to encode + std::vector inp_shape_mod; + std::vector inp_strides_mod; + + for (size_t i = 0, j = 0; i < in.ndim(); i++) { + if (j < axes_.size() && axes_[j] == i) { + j++; + } else { + inp_shape_mod.push_back(in.shape(i)); + inp_strides_mod.push_back(in.strides()[i]); + } + } + + size_t ndim = inp_shape_mod.size(); + + compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5); + compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6); + + if (encode_ndim) { + compute_encoder->setBytes(&ndim, sizeof(size_t), 7); + } + } + + // Select block dimensions + + // Each thread reads 16 inputs to give it more work + uint n_inputs_per_thread = REDUCE_N_READS; + uint n_threads_per_output = + (reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread; + + // We spread outputs over the x dimension and inputs over the y dimension + // Threads with the same lid.x in a given threadgroup work on the same + // output and each thread in the y dimension accumlates for that output + uint threadgroup_dim_x = std::min(out_size, 128ul); + uint threadgroup_dim_y = + kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x; + threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y); + + uint n_threadgroups_x = + (out_size + threadgroup_dim_x - 1) / threadgroup_dim_x; + + uint n_threadgroups_y = + (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; + + // Launch enough thread groups for each output + MTL::Size grid_dims = 
MTL::Size(n_threadgroups_x, n_threadgroups_y, 1); + MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); + + // We set shared memory to be exploited here for reductions within a + // threadgroup - each thread must be able to update its accumulated output + // Note: Each threadgroup should have 32kB of data in threadgroup memory + // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design + // This should be fine for floats, but we might need to revisit + // if we ever come to doubles. In that case, we should also cut + // down the number of threads we launch in a threadgroup + compute_encoder->setThreadgroupMemoryLength( + threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); +} + +void general_reduce_dispatch( + const array& in, + array& out, + const std::string& op_name, + const std::vector& axes_, + MTL::ComputeCommandEncoder* compute_encoder, + metal::Device& d) { + bool encode_ndim = true; + std::ostringstream kernel_name; + kernel_name << "general_reduce_" << op_name << type_to_name(in); + + // Check for specialzed kernels for input ndim + if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) { + kernel_name << "_dim_" << in.ndim(); + encode_ndim = false; + } + auto kernel = d.get_kernel(kernel_name.str()); + size_t in_size = in.size(); + size_t ndim = in.ndim(); + + // We set the reducing strides to 0 to induce collisions for the reduction + std::vector out_strides(ndim); + size_t stride = 1; + for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) { + if (j >= 0 && axes_[j] == i) { + out_strides[i] = 0; + --j; + } else { + out_strides[i] = stride; + stride *= in.shape(i); + } + } + + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2); + compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3); + compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4); + if (encode_ndim) { + compute_encoder->setBytes(&ndim, sizeof(size_t), 5); + } + + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > in_size) { + thread_group_size = in_size; + } + size_t nthreads = in_size; + + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace + +////////////////////////////////////////////////////////////////////// +// Main reduce dispatch +////////////////////////////////////////////////////////////////////// + +void Reduce::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + + // TODO: Allow specific row and column reductions with types disabled + // due to atomics ? + if (size_of(in.dtype()) == 8) { + std::ostringstream msg; + msg << "[Reduce::eval_gpu] Does not support " << in.dtype(); + throw std::runtime_error(msg.str()); + } + + // Make sure no identity reductions trickle down here + assert(!axes_.empty()); + + // Continue with reduction operation + out.set_data(allocator::malloc_or_wait(out.nbytes())); + std::string op_name; + switch (reduce_type_) { + case Reduce::And: + op_name = "and"; + break; + case Reduce::Or: + op_name = "or"; + break; + case Reduce::Sum: + op_name = "sum"; + break; + case Reduce::Prod: + op_name = out.dtype() == bool_ ? 
"and" : "prod"; + break; + case Reduce::Min: + op_name = out.dtype() == bool_ ? "and" : "min_"; + break; + case Reduce::Max: + op_name = out.dtype() == bool_ ? "or" : "max_"; + break; + } + + // Initialize output + auto& s = stream(); + auto& d = metal::device(s.device); + auto compute_encoder = d.get_command_encoder(s.index); + { + auto kernel = d.get_kernel("i" + op_name + type_to_name(out)); + size_t nthreads = out.size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, out, 0); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } + + // Reduce + { + // Check for contiguous data + if (in.size() == in.data_size() && + (in.flags().row_contiguous || in.flags().col_contiguous)) { + // Go to all reduce if reducing over all axes + if (axes_.size() == in.ndim()) { + all_reduce_dispatch(in, out, op_name, compute_encoder, d); + return; + } + // Use specialized kernels if the input is row contiguous and + // the reducing axes can be merged into one + else if ( + in.flags().row_contiguous && in.strides().back() == 1 && + (axes_.back() - axes_.front()) == axes_.size() - 1) { + // If the fastest moving axis is being reduced, go to row reduce + if (axes_[0] == (in.ndim() - axes_.size())) { + row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); + return; + } + // Otherwise go to to generalized strided reduce + // Note: bool isn't support here yet due to the use of atomics + // once that is updated, this should be the else condition of this + // branch + else if (in.dtype() != bool_) { + col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); + return; + } + } + } + // Fall back to the general case + general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_metal/allocator.cpp new file mode 100644 index 000000000..718f3509b --- /dev/null +++ b/mlx/backend/no_metal/allocator.cpp @@ -0,0 +1,15 @@ + +#include "mlx/allocator.h" + +namespace mlx::core::allocator { + +Allocator& allocator() { + static CommonAllocator allocator_; + return allocator_; +} + +void* Buffer::raw_ptr() { + return ptr_; +} + +} // namespace mlx::core::allocator diff --git a/mlx/fft.h b/mlx/fft.h new file mode 100644 index 000000000..62a34f4e9 --- /dev/null +++ b/mlx/fft.h @@ -0,0 +1,149 @@ +#pragma once + +#include + +#include "array.h" +#include "device.h" +#include "stream.h" + +namespace mlx::core::fft { + +using StreamOrDevice = std::variant; + +/** Compute the n-dimensional Fourier Transform. */ +array fftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array fftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); +array fftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse Fourier Transform. */ +array ifftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array ifftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array ifftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform. 
*/ +inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return fftn(a, {n}, {axis}, s); +} +inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return fftn(a, {axis}, s); +} + +/** Compute the one-dimensional inverse Fourier Transform. */ +inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return ifftn(a, {n}, {axis}, s); +} +inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return ifftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform. */ +inline array fft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return fftn(a, n, axes, s); +} +inline array fft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return fftn(a, axes, s); +} + +/** Compute the two-dimensional inverse Fourier Transform. */ +inline array ifft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return ifftn(a, n, axes, s); +} +inline array ifft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return ifftn(a, axes, s); +} + +/** Compute the n-dimensional Fourier Transform on a real input. */ +array rfftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array rfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array rfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse of `rfftn`. */ +array irfftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array irfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array irfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform on a real input. */ +inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return rfftn(a, {n}, {axis}, s); +} +inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return rfftn(a, {axis}, s); +} +/** Compute the one-dimensional inverse of `rfft`. */ +inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return irfftn(a, {n}, {axis}, s); +} +inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return irfftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform on a real input. */ +inline array rfft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return rfftn(a, n, axes, s); +} +inline array rfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return rfftn(a, axes, s); +} + +/** Compute the two-dimensional inverse of `rfft2`. 
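By default the last axis in `axes` comes back with size 2 * (m - 1), where m is the input size along that axis.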
*/ +inline array irfft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return irfftn(a, n, axes, s); +} +inline array irfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return irfftn(a, axes, s); +} + +} // namespace mlx::core::fft diff --git a/mlx/load.cpp b/mlx/load.cpp new file mode 100644 index 000000000..1d3b571ed --- /dev/null +++ b/mlx/load.cpp @@ -0,0 +1,240 @@ +#include +#include +#include +#include +#include + +#include "mlx/load.h" +#include "mlx/ops.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +// Adapted from +// https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp + +namespace mlx::core { + +namespace { + +static constexpr uint8_t MAGIC[] = { + 0x93, + 0x4e, + 0x55, + 0x4d, + 0x50, + 0x59, +}; + +inline bool is_big_endian_() { + union ByteOrder { + int32_t i; + uint8_t c[4]; + }; + ByteOrder b = {0x01234567}; + + return b.c[0] == 0x01; +} + +} // namespace + +/** Save array to out stream in .npy format */ +void save(std::shared_ptr out_stream, array a, bool retain_graph) { + //////////////////////////////////////////////////////// + // Check array + + a.eval(retain_graph); + + if (a.nbytes() == 0) { + throw std::invalid_argument("[save] cannot serialize an empty array"); + } + + if (!a.flags().contiguous) { + throw std::invalid_argument( + "[save] cannot serialize a non-contiguous array"); + } + + //////////////////////////////////////////////////////// + // Check file + if (!out_stream->good() || !out_stream->is_open()) { + throw std::runtime_error("[save] Failed to open " + out_stream->label()); + } + + //////////////////////////////////////////////////////// + // Prepare header + std::ostringstream magic_ver_len; + magic_ver_len.write(reinterpret_cast(MAGIC), 6); + + std::string fortran_order = a.flags().col_contiguous ? 
"True" : "False"; + std::ostringstream header; + header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "'," + << " 'fortran_order': " << fortran_order << "," + << " 'shape': ("; + for (auto i : a.shape()) { + header << i << ", "; + } + header << ")}"; + + size_t header_len = static_cast(header.tellp()); + bool is_v1 = header_len + 15 < std::numeric_limits::max(); + + // Pad out magic + version + header_len + header + \n to be divisible by 16 + size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16; + + header << std::string(padding, ' ') << '\n'; + + if (is_v1) { + magic_ver_len << (char)0x01 << (char)0x00; + + uint16_t v1_header_len = header.tellp(); + const char* len_bytes = reinterpret_cast(&v1_header_len); + + if (!is_big_endian_()) { + magic_ver_len.write(len_bytes, 2); + } else { + magic_ver_len.write(len_bytes + 1, 1); + magic_ver_len.write(len_bytes, 1); + } + } else { + magic_ver_len << (char)0x02 << (char)0x00; + + uint32_t v2_header_len = header.tellp(); + const char* len_bytes = reinterpret_cast(&v2_header_len); + + if (!is_big_endian_()) { + magic_ver_len.write(len_bytes, 4); + } else { + magic_ver_len.write(len_bytes + 3, 1); + magic_ver_len.write(len_bytes + 2, 1); + magic_ver_len.write(len_bytes + 1, 1); + magic_ver_len.write(len_bytes, 1); + } + } + //////////////////////////////////////////////////////// + // Serialize array + + out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length()); + out_stream->write(header.str().c_str(), header.str().length()); + out_stream->write(a.data(), a.nbytes()); + + return; +} + +/** Save array to file in .npy format */ +void save(const std::string& file_, array a, bool retain_graph) { + // Open and check file + std::string file = file_; + + // Add .npy to file name if it is not there + if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy") + file += ".npy"; + + // Serialize array + save(std::make_shared(file), a, retain_graph); +} + +/** Load array from reader in .npy format */ +array load(std::shared_ptr in_stream, StreamOrDevice s) { + //////////////////////////////////////////////////////// + // Open and check file + if (!in_stream->good() || !in_stream->is_open()) { + throw std::runtime_error("[load] Failed to open " + in_stream->label()); + } + + //////////////////////////////////////////////////////// + // Read header and prepare array details + + // Read and check magic + char read_magic_and_ver[8]; + in_stream->read(read_magic_and_ver, 8); + if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) { + throw std::runtime_error("[load] Invalid header in " + in_stream->label()); + } + + // Read and check version + if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) { + throw std::runtime_error( + "[load] Unsupport npy format version in " + in_stream->label()); + } + + // Read header len and header + int header_len_size = read_magic_and_ver[6] == 1 ? 
2 : 4; + size_t header_len; + + if (header_len_size == 2) { + uint16_t v1_header_len; + in_stream->read(reinterpret_cast(&v1_header_len), header_len_size); + header_len = v1_header_len; + } else { + uint32_t v2_header_len; + in_stream->read(reinterpret_cast(&v2_header_len), header_len_size); + header_len = v2_header_len; + } + + // Read the header + std::vector buffer(header_len + 1); + in_stream->read(&buffer[0], header_len); + buffer[header_len] = 0; + std::string header(&buffer[0]); + + // Read data type from header + std::string dtype_str = header.substr(11, 3); + bool read_is_big_endian = dtype_str[0] == '>'; + Dtype dtype = dtype_from_array_protocol(dtype_str); + + // Read contiguity order + bool col_contiguous = header[34] == 'T'; + + // Read array shape from header + std::vector shape; + + size_t st = header.find_last_of('(') + 1; + size_t ed = header.find_last_of(')'); + std::string shape_str = header.substr(st, ed - st); + + while (!shape_str.empty()) { + // Read current number and get position of comma + size_t pos; + int dim = std::stoi(shape_str, &pos); + shape.push_back(dim); + + // Skip the comma and space and read the next number + if (pos + 2 <= shape_str.length()) + shape_str = shape_str.substr(pos + 2); + else { + shape_str = shape_str.substr(pos); + if (!shape_str.empty() && shape_str != " " && shape_str != ",") { + throw std::runtime_error( + "[load] Unknown error while parsing header in " + + in_stream->label()); + } + shape_str = ""; + } + } + + //////////////////////////////////////////////////////// + // Build primitive + + size_t offset = 8 + header_len_size + header.length(); + bool swap_endianness = read_is_big_endian != is_big_endian_(); + + if (col_contiguous) { + std::reverse(shape.begin(), shape.end()); + } + auto loaded_array = array( + shape, + dtype, + std::make_unique(to_stream(s), in_stream, offset, swap_endianness), + std::vector{}); + if (col_contiguous) { + loaded_array = transpose(loaded_array, s); + } + + return loaded_array; +} + +/** Load array from file in .npy format */ +array load(const std::string& file, StreamOrDevice s) { + return load(std::make_shared(file), s); +} + +} // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp new file mode 100644 index 000000000..8a2ade91e --- /dev/null +++ b/mlx/primitives.cpp @@ -0,0 +1,2265 @@ +#include +#include +#include +#include +#include +#include + +#include "mlx/backend/common/utils.h" +#include "mlx/fft.h" +#include "mlx/ops.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +std::tuple vmap_binary_op( + const std::vector& inputs, + const std::vector& axes, + const Stream& stream) { + assert(inputs.size() == 2); + assert(axes.size() == 2); + + auto a = inputs[0]; + auto b = inputs[1]; + int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1)); + + auto expand_dims = [stream, ndim](auto in) { + auto shape = in.shape(); + shape.insert(shape.begin(), ndim - shape.size(), 1); + return reshape(in, shape, stream); + }; + + int to_ax = (ndim - a.ndim()) + axes[0]; + int from_ax = (ndim - b.ndim()) + axes[1]; + a = expand_dims(a); + b = expand_dims(b); + + if (from_ax != to_ax) { + std::vector tdims(b.ndim()); + std::iota(tdims.begin(), tdims.end(), 0); + tdims.erase(tdims.begin() + from_ax); + tdims.insert(tdims.begin() + to_ax, from_ax); + b = transpose(b, tdims, stream); + } + return {a, b, to_ax}; +} + +} // namespace + +array Primitive::jvp( + const std::vector& primals, + const std::vector& tangents, + const 
std::vector& argnums) { + throw std::invalid_argument("Primitive's jvp not implemented."); +}; + +std::vector Primitive::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + throw std::invalid_argument("Primitive's vjp not implemented."); +}; + +std::pair Primitive::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::invalid_argument("Primitive's vmap not implemented."); +}; + +std::vector Abs::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Abs::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return multiply(tangents[0], sign(primals[0], stream()), stream()); +} + +std::pair Abs::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {abs(inputs[0], stream()), axes[0]}; +} + +array Add::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + return tangents.size() > 1 ? add(tangents[0], tangents[1], stream()) + : tangents[0]; +} + +std::vector Add::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + if (argnums.size() == 1) { + return {cotan}; + } else { + return {cotan, cotan}; + } +} + +std::pair Add::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {add(a, b, stream()), to_ax}; +} + +bool Arange::is_equivalent(const Primitive& other) const { + const Arange& a_other = static_cast(other); + return ( + start_ == a_other.start_ && stop_ == a_other.stop_ && + step_ == a_other.step_); +} + +std::vector ArcCos::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array ArcCos::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array one = array(1., primals[0].dtype()); + array t = subtract(one, square(primals[0], stream()), stream()); + array denom = negative(rsqrt(t, stream()), stream()); + return multiply(tangents[0], denom, stream()); +} + +std::pair ArcCos::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {arccos(inputs[0], stream()), axes[0]}; +} + +std::vector ArcCosh::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array ArcCosh::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array one = array(1., primals[0].dtype()); + array t = subtract(square(primals[0], stream()), one, stream()); + return multiply(tangents[0], rsqrt(t, stream()), stream()); +} + +std::pair ArcCosh::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {arccosh(inputs[0], stream()), axes[0]}; +} + +std::vector ArcSin::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array ArcSin::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& 
argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array one = array(1., primals[0].dtype()); + array t = subtract(one, square(primals[0], stream()), stream()); + return multiply(tangents[0], rsqrt(t, stream()), stream()); +} + +std::pair ArcSin::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {arcsin(inputs[0], stream()), axes[0]}; +} + +std::vector ArcSinh::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array ArcSinh::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array one = array(1., primals[0].dtype()); + array t = add(square(primals[0], stream()), one, stream()); + return multiply(tangents[0], rsqrt(t, stream()), stream()); +} + +std::pair ArcSinh::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {arcsinh(inputs[0], stream()), axes[0]}; +} + +std::vector ArcTan::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array ArcTan::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array one = array(1., primals[0].dtype()); + array t = add(one, square(primals[0], stream()), stream()); + return divide(tangents[0], t, stream()); +} + +std::pair ArcTan::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {arctan(inputs[0], stream()), axes[0]}; +} + +std::vector ArcTanh::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array ArcTanh::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array one = array(1., primals[0].dtype()); + array t = subtract(one, square(primals[0], stream()), stream()); + return divide(tangents[0], t, stream()); +} + +std::pair ArcTanh::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {arctanh(inputs[0], stream()), axes[0]}; +} + +std::pair ArgPartition::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + return { + argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]}; +} + +bool ArgPartition::is_equivalent(const Primitive& other) const { + const ArgPartition& r_other = static_cast(other); + return axis_ == r_other.axis_ && kth_ == r_other.kth_; +} + +bool ArgReduce::is_equivalent(const Primitive& other) const { + const ArgReduce& r_other = static_cast(other); + return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_; +} + +std::pair ArgSort::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + return {argsort(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]}; +} + +bool ArgSort::is_equivalent(const Primitive& other) const { + const ArgSort& r_other = static_cast(other); + return axis_ == r_other.axis_; +} + +std::vector AsType::vjp( + const 
std::vector& primals, + const array& cotan, + const std::vector& argnums) { + if (cotan.dtype() != dtype_) { + throw std::invalid_argument( + "[astype] Type of cotangent does not much primal output type."); + } + return {astype(cotan, primals[0].dtype(), stream())}; +} + +array AsType::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + return astype(tangents[0], dtype_, stream()); +} + +std::pair AsType::vmap( + const std::vector& inputs, + const std::vector& axes) { + return {astype(inputs[0], dtype_, stream()), axes[0]}; +} + +bool AsType::is_equivalent(const Primitive& other) const { + const AsType& a_other = static_cast(other); + return dtype_ == a_other.dtype_; +} + +std::vector AsStrided::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(argnums.size() == 1); + + // Extract the sizes and cast them to ints + int grad_size = primals[0].size(); + int cotan_size = cotan.size(); + + // Make a flat container to hold the gradients + auto grad = zeros_like(primals[0], stream()); + grad = reshape(grad, {grad_size}, stream()); + + // Create the indices that map output to input + auto idx = arange(grad_size, stream()); + idx = as_strided(idx, shape_, strides_, offset_, stream()); + idx = reshape(idx, {cotan_size}, stream()); + + // Reshape the cotangent for use with scatter + auto flat_cotan = reshape(cotan, {cotan_size, 1}, stream()); + + // Finally accumulate the gradients and reshape them to look like the input + grad = scatter_add(grad, idx, flat_cotan, 0, stream()); + grad = reshape(grad, primals[0].shape(), stream()); + + return {grad}; +} + +array AsStrided::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + + return as_strided(tangents[0], shape_, strides_, offset_, stream()); +} + +bool AsStrided::is_equivalent(const Primitive& other) const { + const AsStrided& a_other = static_cast(other); + return shape_ == a_other.shape_ && strides_ == a_other.strides_ && + offset_ == a_other.offset_; +} + +std::vector Broadcast::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(argnums.size() == 1); + + // Reduce cotan to the shape of the primal + auto& shape = primals[0].shape(); + int diff = cotan.ndim() - shape.size(); + std::vector reduce_axes; + for (int i = 0; i < cotan.ndim(); ++i) { + if (i < diff) { + reduce_axes.push_back(i); + } else if (shape[i - diff] != cotan.shape(i)) { + reduce_axes.push_back(i); + } + } + return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())}; +} + +array Broadcast::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(argnums.size() == 1); + return broadcast_to(tangents[0], shape_, stream()); +} + +std::pair Broadcast::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + auto ax = axes[0]; + auto in_shape = inputs[0].shape(); + int diff = shape_.size() - inputs[0].ndim() + 1; + assert(diff >= 0); + in_shape.insert(in_shape.begin(), diff, 1); + ax += diff; + shape_.insert(shape_.begin() + ax, in_shape[ax]); + auto in = reshape(inputs[0], in_shape, stream()); + return {broadcast_to(in, shape_, stream()), ax}; +} + +bool Broadcast::is_equivalent(const Primitive& other) const { + const Broadcast& b_other = static_cast(other); + return shape_ == b_other.shape_; +} + +std::vector Concatenate::vjp( + const 
std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector start(cotan.ndim(), 0); + std::vector stop = cotan.shape(); + + std::vector sizes; + sizes.push_back(0); + for (auto& p : primals) { + sizes.push_back(p.shape(axis_)); + } + std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); + + std::vector grads; + for (auto i : argnums) { + start[axis_] = sizes[i]; + stop[axis_] = sizes[i + 1]; + grads.push_back(slice(cotan, start, stop, stream())); + } + return grads; +} + +array Concatenate::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + std::vector argidx(argnums.size()); + std::iota(argidx.begin(), argidx.end(), 0); + std::sort(argidx.begin(), argidx.end(), [&argnums](int a, int b) { + return argnums[a] < argnums[b]; + }); + + std::vector vals; + for (int i = 0, j = 0; i < primals.size(); ++i) { + if (j < argnums.size() && argnums[argidx[j]] == i) { + vals.push_back(tangents[argidx[j++]]); + } else { + vals.push_back(zeros_like(primals[i], stream())); + } + } + return concatenate(vals, axis_, stream()); +} + +std::pair Concatenate::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::runtime_error("Concatenate vmap is NYI."); +} + +bool Concatenate::is_equivalent(const Primitive& other) const { + const Concatenate& c_other = static_cast(other); + return axis_ == c_other.axis_; +} + +std::vector Convolution::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(primals.size() == 2); + std::vector grads; + + // Collect info + auto& in = primals[0]; + auto& wt = primals[1]; + + int N = in.shape(0); + int O = wt.shape(0); + + // Resolve Padded input shapes and strides + std::vector padding_starts(in.ndim(), 0); + std::vector padding_ends = in.shape(); + std::vector in_padded_shape = in.shape(); + + // padded shape + for (int i = 1; i < in.ndim() - 1; i++) { + in_padded_shape[i] += 2 * padding_[i - 1]; + padding_ends[i] += padding_[i - 1]; + padding_starts[i] += padding_[i - 1]; + } + + // padded strides (contiguous) + std::vector in_padded_strides(in.ndim(), 1); + for (int i = in.ndim() - 2; i >= 0; --i) { + in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1]; + } + + // Resolve strided patches + + // patches are shaped as + // (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels) + std::vector patches_shape{ + cotan.shape().begin(), cotan.shape().end() - 1}; + patches_shape.insert( + patches_shape.end(), wt.shape().begin() + 1, wt.shape().end()); + + // Resolve patch strides + int n_spatial_dim = in.ndim() - 2; + std::vector patches_strides(patches_shape.size(), 1); + patches_strides[0] = in_padded_strides[0]; + for (int i = 1; i < n_spatial_dim + 1; i++) { + patches_strides[i] = in_padded_strides[i] * kernel_strides_[i - 1]; + } + for (int i = 1; i < in.ndim(); i++) { + patches_strides[n_spatial_dim + i] = in_padded_strides[i]; + } + + // Reshape cotan and weights for gemm + auto cotan_reshaped = reshape(cotan, {-1, O}, stream()); + auto weight_reshaped = reshape(wt, {O, -1}, stream()); + + for (int a : argnums) { + // Grads for input + if (a == 0) { + // Gemm with cotan to get patches + auto grad_patches = matmul(cotan_reshaped, weight_reshaped, stream()); + + // Prepare base grad array to accumulate on + int in_padded_size = in_padded_strides[0] * in_padded_shape[0]; + auto grad = zeros( + { + in_padded_size, + }, + in.dtype(), + stream()); + + // Create index map + int patches_size = 
grad_patches.size(); + auto idx = arange(in_padded_size, stream()); + idx = as_strided(idx, patches_shape, patches_strides, 0, stream()); + idx = reshape(idx, {patches_size}, stream()); + + // Flatten patches and scatter + auto flat_patches = reshape(grad_patches, {patches_size, 1}, stream()); + grad = scatter_add(grad, idx, flat_patches, 0, stream()); + + // Reshape and slice away padding + grad = reshape(grad, in_padded_shape, stream()); + grad = slice(grad, padding_starts, padding_ends, stream()); + + grads.push_back(grad); + } + // Grads for weight + else if (a == 1) { + // Make patches from in + std::vector padded_axes(in.ndim() - 2, 0); + std::iota(padded_axes.begin(), padded_axes.end(), 1); + auto in_padded = pad( + in, padded_axes, padding_, padding_, array(0, in.dtype()), stream()); + auto in_patches = + as_strided(in_padded, patches_shape, patches_strides, 0, stream()); + in_patches = reshape(in_patches, {cotan_reshaped.shape(0), -1}, stream()); + + auto grad = matmul( + transpose(cotan_reshaped, {1, 0}, stream()), in_patches, stream()); + grad = reshape(grad, wt.shape(), stream()); + grads.push_back(grad); + } + } + + return grads; +} + +bool Convolution::is_equivalent(const Primitive& other) const { + const Convolution& c_other = static_cast(other); + return padding_ == c_other.padding_ && + kernel_strides_ == c_other.kernel_strides_ && + kernel_dilation_ == c_other.kernel_dilation_ && + input_dilation_ == c_other.input_dilation_; +} + +std::vector Copy::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return {cotan}; +} + +array Copy::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return tangents[0]; +} + +std::pair Copy::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {copy(inputs[0], stream()), axes[0]}; +} + +std::vector Cos::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Cos::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return multiply( + tangents[0], negative(sin(primals[0], stream()), stream()), stream()); +} + +std::pair Cos::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {cos(inputs[0], stream()), axes[0]}; +} + +std::vector Cosh::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Cosh::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return multiply(tangents[0], sinh(primals[0], stream()), stream()); +} + +std::pair Cosh::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {cosh(inputs[0], stream()), axes[0]}; +} + +std::vector Divide::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + if (arg == 0) { + vjps.push_back(divide(cotan, primals[1], stream())); + } else { + vjps.push_back(negative( + divide( + 
multiply(cotan, primals[0], stream()), + square(primals[1], stream()), + stream()), + stream())); + } + } + return vjps; +} + +array Divide::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto jvp_fun = [&](int i) { + int arg = argnums[i]; + if (arg == 0) { + return divide(tangents[i], primals[1], stream()); + } else { + return negative( + divide( + multiply(tangents[i], primals[0], stream()), + square(primals[1], stream()), + stream()), + stream()); + } + }; + auto out = jvp_fun(0); + if (argnums.size() > 1) { + out = add(out, jvp_fun(1), stream()); + } + return out; +} + +std::pair Divide::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {divide(a, b, stream()), to_ax}; +} + +std::pair Equal::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {equal(a, b, stream()), axes[0]}; +} + +std::vector Equal::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(zeros_like(primals[arg], stream())); + } + return vjps; +} + +array Equal::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); + return zeros(shape, bool_, stream()); +} + +std::vector Erf::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Erf::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto dtype = primals[0].dtype(); + auto scale = multiply(array(M_2_SQRTPI, dtype), tangents[0], stream()); + return multiply( + scale, + exp(negative(square(primals[0], stream()), stream()), stream()), + stream()); +} + +std::pair Erf::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {erf(inputs[0], stream()), axes[0]}; +} + +std::vector ErfInv::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array ErfInv::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto dtype = primals[0].dtype(); + auto scale = multiply(array(1.0 / M_2_SQRTPI, dtype), tangents[0], stream()); + return multiply( + scale, + exp(square(erfinv(primals[0], stream()), stream()), stream()), + stream()); +} + +std::pair ErfInv::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {erfinv(inputs[0], stream()), axes[0]}; +} + +std::vector Exp::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Exp::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return multiply(tangents[0], exp(primals[0], stream()), stream()); +} + +std::pair Exp::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {exp(inputs[0], stream()), 
axes[0]}; +} + +bool FFT::is_equivalent(const Primitive& other) const { + const FFT& r_other = static_cast(other); + return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ && + real_ == r_other.real_; +} + +std::pair FFT::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto& in = inputs[0]; + int ax = axes[0]; + auto fft_axes = axes_; + auto out_shape = in.shape(); + for (auto& fft_ax : fft_axes) { + if (fft_ax >= ax) { + fft_ax++; + } + if (real_) { + auto n = out_shape[fft_ax]; + out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1; + } + } + return { + array( + out_shape, + real_ && inverse_ ? float32 : complex64, + std::make_unique(stream(), fft_axes, inverse_, real_), + {in}), + ax}; +} + +std::vector FFT::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto& in = primals[0]; + std::vector axes(axes_.begin(), axes_.end()); + if (real_ && inverse_) { + auto out = fft::fftn(cotan, axes, stream()); + auto start = std::vector(out.ndim(), 0); + auto stop = in.shape(); + out = slice(out, start, stop, stream()); + auto mask_shape = out.shape(); + mask_shape[axes_.back()] -= 2; + auto mask = full(mask_shape, 2.0f, stream()); + auto pad_shape = out.shape(); + pad_shape[axes_.back()] = 1; + auto pad = full(pad_shape, 1.0f, stream()); + mask = concatenate({pad, mask, pad}, axes_.back(), stream()); + return {multiply(mask, out, stream())}; + } else if (real_) { + std::vector n; + for (auto ax : axes_) { + n.push_back(in.shape()[ax]); + } + return {astype(fft::fftn(cotan, n, axes, stream()), in.dtype(), stream())}; + } else if (inverse_) { + return {fft::ifftn(cotan, axes, stream())}; + } else { + return {fft::fftn(cotan, axes, stream())}; + } +} + +array FFT::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto& tan = tangents[0]; + if (real_ & inverse_) { + return fft::irfftn(tan, stream()); + } else if (real_) { + return fft::rfftn(tan, stream()); + } else if (inverse_) { + return fft::ifftn(tan, stream()); + } else { + return fft::fftn(tan, stream()); + } +} + +std::vector Full::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return {multiply(cotan, primals[0], stream())}; +} + +array Full::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return tangents[0]; +} + +std::pair Full::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + auto& in = inputs[0]; + auto out = + array(in.shape(), in.dtype(), std::make_unique(stream()), {in}); + return {out, axes[0]}; +} + +std::pair Gather::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::runtime_error("Gather vmap is NYI, please change slices instead"); +} + +std::vector Gather::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + if (argnums.size() > 1 || argnums[0] != 0) { + throw std::invalid_argument( + "[gather] Cannot calculate VJP with respect to indices."); + } + auto src = zeros_like(primals[0], stream()); + std::vector inds(primals.begin() + 1, primals.end()); + return {scatter_add(src, inds, cotan, axes_, stream())}; +} + +array Gather::jvp( + const 
std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + if (argnums.size() > 1 || argnums[0] != 0) { + throw std::invalid_argument( + "[gather] Cannot calculate JVP with respect to indices."); + } + std::vector inds(primals.begin() + 1, primals.end()); + return gather(tangents[0], inds, axes_, slice_sizes_, stream()); +} + +bool Gather::is_equivalent(const Primitive& other) const { + const Gather& g_other = static_cast(other); + return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_; +} + +std::pair Greater::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {greater(a, b, stream()), axes[0]}; +} + +std::vector Greater::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(zeros_like(primals[arg], stream())); + } + return vjps; +} + +array Greater::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); + return zeros(shape, bool_, stream()); +} + +std::pair GreaterEqual::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {greater_equal(a, b, stream()), axes[0]}; +} + +std::vector GreaterEqual::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(zeros_like(primals[arg], stream())); + } + return vjps; +} + +array GreaterEqual::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); + return zeros(shape, bool_, stream()); +} + +std::pair Less::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {less(a, b, stream()), axes[0]}; +} + +std::vector Less::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(zeros_like(primals[arg], stream())); + } + return vjps; +} + +array Less::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); + return zeros(shape, bool_, stream()); +} + +std::pair LessEqual::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {less_equal(a, b, stream()), axes[0]}; +} + +std::vector LessEqual::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(zeros_like(primals[arg], stream())); + } + return vjps; +} + +array LessEqual::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); + return zeros(shape, bool_, stream()); +} + +std::vector Log::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Log::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + 
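+  // d/dx log_b(x) = 1 / (x * ln(b)): divide the tangent by the primal, then
+  // rescale by 1 / ln(10) or 1 / ln(2) below when the base is not e.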
auto out = divide(tangents[0], primals[0], stream()); + if (base_ != Base::e) { + auto scale = 1 / std::log(base_ == Base::ten ? 10.0f : 2.0f); + out = multiply(array(scale, out.dtype()), out, stream()); + } + return out; +} + +std::pair Log::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + auto& in = inputs[0]; + return { + array( + in.shape(), in.dtype(), std::make_unique(stream(), base_), {in}), + axes[0]}; +} + +std::vector Log1p::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Log1p::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto dtype = primals[0].dtype(); + return divide( + tangents[0], add(array(1.0f, dtype), primals[0], stream()), stream()); +} + +std::pair Log1p::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {log1p(inputs[0], stream()), axes[0]}; +} + +std::vector LogicalNot::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array LogicalNot::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return zeros_like(tangents[0], stream()); +} + +std::pair LogicalNot::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {logical_not(inputs[0], stream()), axes[0]}; +} + +std::vector LogAddExp::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + auto a = primals[0]; + auto b = primals[1]; + auto s = sigmoid(subtract(a, b, stream()), stream()); + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(multiply( + cotan, + arg == 0 ? s : subtract(array(1.0f, s.dtype()), s, stream()), + stream())); + } + return vjps; +} + +array LogAddExp::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto a = primals[0]; + auto b = primals[1]; + auto s = sigmoid(subtract(a, b, stream()), stream()); + auto jvp_fun = [&](int i) { + int arg = argnums[i]; + return multiply( + tangents[i], + arg == 0 ? 
s : subtract(array(1.0f, s.dtype()), s, stream()), + stream()); + }; + auto out = jvp_fun(0); + if (argnums.size() > 1) { + out = add(out, jvp_fun(1), stream()); + } + return out; +} + +std::pair LogAddExp::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {logaddexp(a, b, stream()), to_ax}; +} + +std::vector Matmul::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + std::vector reorder(cotan.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + std::iter_swap(reorder.end() - 1, reorder.end() - 2); + for (auto arg : argnums) { + if (arg == 0) { + // M X N * (K X N).T -> M X K + vjps.push_back( + matmul(cotan, transpose(primals[1], reorder, stream()), stream())); + } else { + // (M X K).T * M X N -> K X N + vjps.push_back( + matmul(transpose(primals[0], reorder, stream()), cotan, stream())); + } + } + return vjps; +} + +std::pair Matmul::vmap( + const std::vector& inputs, + const std::vector& axes) { + return {array(1.0), 0}; +} + +std::vector Maximum::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + auto& a = primals[0]; + auto& b = primals[1]; + std::vector vjps; + for (auto arg : argnums) { + auto mask = + (arg == 0) ? greater(a, b, stream()) : less_equal(a, b, stream()); + vjps.push_back(multiply(cotan, mask, stream())); + } + return vjps; +} + +array Maximum::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto& a = primals[0]; + auto& b = primals[1]; + auto jvp_fun = [&](int i) { + int arg = argnums[i]; + auto mask = + (arg == 0) ? greater(a, b, stream()) : less_equal(a, b, stream()); + return multiply(tangents[i], mask, stream()); + }; + auto out = jvp_fun(0); + if (argnums.size() > 1) { + out = add(out, jvp_fun(1), stream()); + } + return out; +} + +std::pair Maximum::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {maximum(a, b, stream()), to_ax}; +} + +std::vector Minimum::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + auto& a = primals[0]; + auto& b = primals[1]; + std::vector vjps; + for (auto arg : argnums) { + auto mask = + (arg == 0) ? less(a, b, stream()) : greater_equal(a, b, stream()); + vjps.push_back(multiply(cotan, mask, stream())); + } + return vjps; +} + +array Minimum::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto& a = primals[0]; + auto& b = primals[1]; + auto jvp_fun = [&](int i) { + int arg = argnums[i]; + auto mask = + (arg == 0) ? 
less(a, b, stream()) : greater_equal(a, b, stream()); + return multiply(tangents[i], mask, stream()); + }; + auto out = jvp_fun(0); + if (argnums.size() > 1) { + out = add(out, jvp_fun(1), stream()); + } + return out; +} + +std::pair Minimum::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {minimum(a, b, stream()), to_ax}; +} + +array Multiply::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto arg = argnums[0]; + auto jvp = multiply(tangents[0], primals[1 - arg], stream()); + if (argnums.size() > 1) { + arg = argnums[1]; + jvp = add(jvp, multiply(tangents[1], primals[1 - arg], stream()), stream()); + } + return jvp; +} + +std::vector Multiply::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(multiply(primals[1 - arg], cotan, stream())); + } + return vjps; +} + +std::pair Multiply::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {multiply(a, b, stream()), to_ax}; +} + +std::vector Negative::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Negative::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return negative(tangents[0], stream()); +} + +std::pair Negative::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {negative(inputs[0], stream()), axes[0]}; +} + +std::pair NotEqual::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {not_equal(a, b, stream()), axes[0]}; +} + +std::vector NotEqual::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + vjps.push_back(zeros_like(primals[arg], stream())); + } + return vjps; +} + +array NotEqual::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); + return zeros(shape, bool_, stream()); +} + +std::vector Pad::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(argnums.size() == 1 && argnums[0] == 0); + + std::vector start(cotan.ndim(), 0); + std::vector stop = cotan.shape(); + + for (auto i : axes_) { + start[i] = low_pad_size_[i]; + stop[i] -= high_pad_size_[i]; + } + + auto out = slice(cotan, start, stop, stream()); + + return {out}; +} + +array Pad::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(argnums.size() == 1 && argnums[0] == 0); + + return pad( + tangents[0], + axes_, + low_pad_size_, + high_pad_size_, + array(0, tangents[0].dtype()), + stream()); +} + +std::pair Pad::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::runtime_error("Pad vmap is NYI."); +} + +bool Pad::is_equivalent(const Primitive& other) const { + const Pad& p_other = static_cast(other); + return ( + p_other.axes_ == axes_ && p_other.low_pad_size_ == low_pad_size_ && + p_other.high_pad_size_ == high_pad_size_); +} + +std::vector Partition::vjp( + 
const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Partition::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(tangents.size() == 1); + auto sort_idx = argpartition(primals[0], kth_, axis_, stream()); + auto out = take_along_axis(tangents[0], sort_idx, axis_, stream()); + return out; +} + +std::pair Partition::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + return {partition(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]}; +} + +bool Partition::is_equivalent(const Primitive& other) const { + const Partition& r_other = static_cast(other); + return axis_ == r_other.axis_ && kth_ == r_other.kth_; +} + +std::vector Power::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + if (arg == 0) { + vjps.push_back(multiply( + power( + primals[0], + subtract(primals[1], array(1, primals[0].dtype()), stream()), + stream()), + primals[1], + stream())); + } else { + vjps.push_back(multiply( + log(primals[0], stream()), + power(primals[0], primals[1], stream()), + stream())); + } + vjps.back() = multiply(cotan, vjps.back(), stream()); + } + return vjps; +} + +array Power::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto jvp = vjp(primals, tangents[0], {argnums[0]})[0]; + if (argnums.size() > 1) { + jvp = add(jvp, vjp(primals, tangents[1], {argnums[1]})[0], stream()); + } + return jvp; +} + +std::pair Power::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {power(a, b, stream()), to_ax}; +} + +std::pair RandomBits::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + // The last dimension of the key is always a key pair + auto key = inputs[0]; + auto kax = axes[0]; + if (kax == key.ndim() - 1) { + std::vector reorder(key.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + std::swap(reorder[kax], reorder[kax - 1]); + key = transpose(key, reorder, stream()); + kax--; + } + + auto shape = shape_; + shape.insert(shape.begin() + kax, key.shape()[kax]); + + auto get_dtype = [width = width_]() { + switch (width) { + case 1: + return uint8; + case 2: + return uint16; + default: + return uint32; + } + }; + + auto out = array( + shape, + get_dtype(), + std::make_unique(stream(), shape, width_), + {key}); + return {out, kax}; +} + +bool RandomBits::is_equivalent(const Primitive& other) const { + const RandomBits& r_other = static_cast(other); + return shape_ == r_other.shape_; +} + +std::pair Reshape::vmap( + const std::vector& inputs, + const std::vector& axes) { + // Transpose the input so that the vmap dim is first. + auto& in = inputs[0]; + auto ax = axes[0]; + std::vector reorder(in.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + reorder.erase(reorder.begin() + ax); + reorder.insert(reorder.begin(), ax); + // Insert the vmap dim into the shape at the beginning. + auto out = transpose(in, reorder, stream()); + shape_.insert(shape_.begin(), in.shape()[ax]); + // Reshape the transposed input to the new shape. 
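The Power VJP above applies the textbook rules: the incoming cotangent is scaled by b * a^(b - 1) for the base and by a^b * log(a) for the exponent. A minimal NumPy sketch of the same two rules, checked against central finite differences (the scalar values are arbitrary illustrations, not anything taken from the mlx sources):

import numpy as np

# Hedged sketch: the two power-rule VJPs implemented by Power::vjp, verified
# numerically. NumPy stands in for mlx arrays; the values are arbitrary.
a, b, cotan = 1.7, 2.3, 0.5
vjp_a = cotan * b * a ** (b - 1.0)   # cotangent * d(a**b)/da
vjp_b = cotan * np.log(a) * a ** b   # cotangent * d(a**b)/db

eps = 1e-6
fd_a = cotan * ((a + eps) ** b - (a - eps) ** b) / (2 * eps)
fd_b = cotan * (a ** (b + eps) - a ** (b - eps)) / (2 * eps)
assert np.allclose([vjp_a, vjp_b], [fd_a, fd_b], atol=1e-5)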
+ return {reshape(out, shape_, stream()), 0}; +} + +std::vector Reshape::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + assert(argnums[0] == 0); + return {reshape(cotan, primals[0].shape(), stream())}; +} + +array Reshape::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + assert(argnums[0] == 0); + return reshape(tangents[0], shape_, stream()); +} + +bool Reshape::is_equivalent(const Primitive& other) const { + const Reshape& r_other = static_cast(other); + return shape_ == r_other.shape_; +} + +std::vector Reduce::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + auto in = primals[0]; + + std::vector shape = in.shape(); + for (auto ax : axes_) { + shape[ax] = 1; + } + + if (reduce_type_ == Reduce::Sum) { + return { + broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())}; + } else if (reduce_type_ == Reduce::Prod) { + auto s = stream(); + auto prod_grad_single_axis = + [&s](const array& x, const array& cotan, int axis) { + auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s); + auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s); + auto exclusive_prod = multiply(p1, p2, s); + return multiply(exclusive_prod, cotan, s); + }; + + // To compute a numerically stable gradient for prod we need an exclusive + // product of all elements in axes_ . To achieve that we move axes_ to the + // last dim and perform two exclusive cumprods. Afterwards we move + // everything back to the original axes. + if (axes_.size() > 1) { + std::vector transpose_to; + std::vector transpose_back; + std::vector shape_flat; + { + // Find the transpose needed to move axes_ to the back and the shape + // except the reduced over axes. 
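The comment above describes the trick behind the Prod gradient: the derivative of prod(x) with respect to x[i] is the product of every other element, and a forward plus a reverse exclusive cumulative product yields exactly that without ever dividing by x[i], which is what keeps it well behaved when some elements are zero. A small NumPy illustration of the identity (not mlx API):

import numpy as np

# Hedged sketch of the exclusive-product identity used by the Prod VJP above.
x = np.array([2.0, 3.0, 0.0, 7.0])                 # a zero would break prod(x) / x[i]
fwd = np.concatenate(([1.0], np.cumprod(x)[:-1]))  # exclusive cumprod, left to right
rev = np.concatenate((np.cumprod(x[::-1])[:-1][::-1], [1.0]))  # exclusive cumprod, right to left
grad_prod = fwd * rev                              # product of all elements except x[i]
expected = [3.0 * 0.0 * 7.0, 2.0 * 0.0 * 7.0, 2.0 * 3.0 * 7.0, 2.0 * 3.0 * 0.0]
assert np.allclose(grad_prod, expected)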
+ int j = 0; + for (int i = 0; i < in.ndim(); i++) { + if (j < axes_.size() && axes_[j] == i) { + j++; + } else { + transpose_to.push_back(i); + shape_flat.push_back(in.shape(i)); + } + } + for (auto ax : axes_) { + transpose_to.push_back(ax); + } + shape_flat.push_back(-1); + transpose_back.resize(transpose_to.size()); + for (int i = 0; i < transpose_to.size(); i++) { + transpose_back[transpose_to[i]] = i; + } + } + + // Move axes to the back + auto x = transpose(in, transpose_to, s); + // Keep the shape in order to reshape back to the original + auto shape_to = x.shape(); + + // Flatten and compute the gradient + x = reshape(x, shape_flat, stream()); + auto grad = prod_grad_single_axis(x, reshape(cotan, shape_flat, s), -1); + + // Reshape and transpose to the original shape + grad = reshape(grad, shape_to, s); + grad = transpose(grad, transpose_back, s); + + return {grad}; + } else { + return {prod_grad_single_axis(in, reshape(cotan, shape, s), axes_[0])}; + } + + } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) { + array (*op)(const array&, const std::vector&, bool, StreamOrDevice); + + if (reduce_type_ == Reduce::Min) { + op = min; + } else { + op = max; + } + + auto out = op(in, axes_, true, stream()); + auto mask = equal(in, out, stream()); + auto normalizer = sum(mask, axes_, true, stream()); + auto cotan_reshape = reshape(cotan, shape, stream()); + cotan_reshape = divide(cotan_reshape, normalizer, stream()); + return {multiply(cotan_reshape, mask, stream())}; + } + + else { + throw std::runtime_error("Reduce type VJP not yet implemented."); + } +} + +std::pair Reduce::vmap( + const std::vector& inputs, + const std::vector& axes) { + // TODO implement + return {array(1.0f), axes[0]}; +} + +bool Reduce::is_equivalent(const Primitive& other) const { + const Reduce& r_other = static_cast(other); + return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_; +} + +std::pair Scan::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto& in = inputs[0]; + auto axis = axes[0]; + + auto out_dtype = + (in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? 
int32 : in.dtype(); + return { + array( + in.shape(), + out_dtype, + std::make_unique( + stream(), + reduce_type_, + axis_ + (axis <= axis_), + reverse_, + inclusive_), + {in}), + axis}; +} + +std::vector Scan::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums[0] == 0); + + if (reduce_type_ == Scan::Sum) { + return {cumsum(cotan, axis_, !reverse_, inclusive_, stream())}; + } else if (reduce_type_ == Scan::Prod) { + // TODO: Make it numerically stable when we introduce where() + auto prod = cumprod(primals[0], axis_, reverse_, inclusive_, stream()); + auto partial_grads = multiply(prod, cotan, stream()); + auto accum_grads = + cumsum(partial_grads, axis_, !reverse_, inclusive_, stream()); + return {divide(accum_grads, primals[0], stream())}; + } else { + // Can probably be implemented by equals and then cummax to make the mask + throw std::runtime_error("VJP is not implemented for cumulative min/max"); + } +} + +array Scan::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(tangents.size() == 1); + assert(argnums[0] == 0); + + if (reduce_type_ == Scan::Sum) { + return cumsum(tangents[0], axis_, reverse_, inclusive_, stream()); + } else { + throw std::runtime_error( + "JVP is not implemented for cumulative prod/min/max"); + } +} + +bool Scan::is_equivalent(const Primitive& other) const { + const Scan& s_other = static_cast(other); + return ( + reduce_type_ == s_other.reduce_type_ && axis_ == s_other.axis_ && + reverse_ == s_other.reverse_ && inclusive_ == s_other.inclusive_); +} + +bool Scatter::is_equivalent(const Primitive& other) const { + const Scatter& s_other = static_cast(other); + return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_; +} + +std::vector Sigmoid::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Sigmoid::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto s = sigmoid(primals[0], stream()); + auto sprime = + multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream()); + return multiply(tangents[0], sprime, stream()); +} + +std::pair Sigmoid::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {sigmoid(inputs[0], stream()), axes[0]}; +} + +std::vector Sign::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Sign::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return zeros(primals[0].shape(), primals[0].dtype(), stream()); +} + +std::pair Sign::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {sign(inputs[0], stream()), axes[0]}; +} + +std::vector Sin::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Sin::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return multiply(tangents[0], cos(primals[0], stream()), stream()); +} + +std::pair 
Sin::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {sin(inputs[0], stream()), axes[0]}; +} + +std::vector Sinh::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Sinh::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return multiply(tangents[0], cosh(primals[0], stream()), stream()); +} + +std::pair Sinh::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {sinh(inputs[0], stream()), axes[0]}; +} + +std::pair Slice::vmap( + const std::vector& inputs, + const std::vector& axes) { + // TODO implement + return {array(1.0f), axes[0]}; +} + +std::vector Slice::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + // Check inputs + assert(primals.size() == 1); + + std::vector inds; + std::vector ind_axes; + std::vector single_inds; + std::vector single_ind_axes; + for (int i = 0; i < start_indices_.size(); ++i) { + auto start = start_indices_[i]; + auto end = end_indices_[i]; + auto stride = strides_[i]; + if (start == 0 && stride == 1) { + continue; + } + if (stride == 1) { + single_inds.push_back(array(start)); + single_ind_axes.push_back(i); + } else { + inds.push_back(arange(start, end, stride, stream())); + ind_axes.push_back(i); + } + } + + // Transpose and reshape cotan + auto cotan_ = cotan; + if (!ind_axes.empty()) { + std::vector cotan_shape; + for (auto ax : ind_axes) { + cotan_shape.push_back(cotan.shape(ax)); + } + std::vector cotan_axes(ind_axes); + for (int j = 0, i = 0; i < cotan.ndim(); ++i) { + if (j < ind_axes.size() && ind_axes[j] == i) { + cotan_shape.push_back(1); + j++; + } else { + cotan_shape.push_back(cotan.shape(i)); + cotan_axes.push_back(i); + } + } + cotan_ = + reshape(transpose(cotan_, cotan_axes, stream()), cotan_shape, stream()); + } + + // Make indices broadcastable + std::vector inds_shape(inds.size(), 1); + for (int i = 0; i < inds.size(); ++i) { + inds_shape[i] = inds[i].size(); + inds[i] = reshape(inds[i], inds_shape, stream()); + inds_shape[i] = 1; + } + + // Concatenate all the indices and axes + inds.insert(inds.end(), single_inds.begin(), single_inds.end()); + ind_axes.insert( + ind_axes.end(), single_ind_axes.begin(), single_ind_axes.end()); + + return {scatter_add( + zeros_like(primals[0], stream()), inds, cotan_, ind_axes, stream())}; +} + +array Slice::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + // Check inputs + assert(primals.size() == 1); + return slice(tangents[0], start_indices_, end_indices_, strides_, stream()); +} + +bool Slice::is_equivalent(const Primitive& other) const { + const Slice& s_other = static_cast(other); + return ( + start_indices_ == s_other.start_indices_ && + end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_); +} + +std::pair Softmax::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + std::vector softmax_axes; + + // We are vectorizing over an axis other than the last one so keep the + // softmax axis unchanged + if (axes[0] < inputs[0].ndim() - 1) { + softmax_axes.push_back(-1); + } else { + softmax_axes.push_back(-2); + } + return {softmax(inputs[0], softmax_axes, stream()), 
axes[0]}; +} + +std::vector Softmax::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Softmax::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(tangents.size() == 1); + auto s = softmax(primals[0], std::vector{-1}, stream()); + auto sv = multiply(s, tangents[0], stream()); + return subtract( + sv, multiply(s, sum(sv, std::vector{-1}, true, stream()), stream())); +} + +std::pair Sort::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + return {sort(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]}; +} + +std::vector Sort::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Sort::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(tangents.size() == 1); + auto sort_idx = argsort(primals[0], axis_, stream()); + auto out = take_along_axis(tangents[0], sort_idx, axis_, stream()); + return out; +} + +bool Sort::is_equivalent(const Primitive& other) const { + const Sort& r_other = static_cast(other); + return axis_ == r_other.axis_; +} + +std::vector Square::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Square::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(tangents.size() == 1); + return multiply( + primals[0], + multiply(array(2, primals[0].dtype()), tangents[0], stream()), + stream()); +} + +std::pair Square::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {square(inputs[0], stream()), axes[0]}; +} + +std::vector Sqrt::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Sqrt::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(tangents.size() == 1); + auto dtype = primals[0].dtype(); + if (recip_) { + auto one_over_x_root_x = + divide(rsqrt(primals[0], stream()), primals[0], stream()); + return multiply( + multiply(array(-0.5, dtype), tangents[0], stream()), + one_over_x_root_x, + stream()); + } + return divide( + multiply(array(0.5, dtype), tangents[0], stream()), + sqrt(primals[0], stream()), + stream()); +} + +std::pair Sqrt::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + if (recip_) + return {rsqrt(inputs[0], stream()), axes[0]}; + + return {sqrt(inputs[0], stream()), axes[0]}; +} + +bool Sqrt::is_equivalent(const Primitive& other) const { + const Sqrt& s_other = static_cast(other); + return recip_ == s_other.recip_; +} + +std::pair StopGradient::vmap( + const std::vector& inputs, + const std::vector& axes) { + return {inputs[0], axes[0]}; +}; + +std::vector Subtract::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + std::vector vjps; + for (auto arg : argnums) { + auto vjp = cotan; + if (arg == 1) { + vjp = negative(vjp, stream()); + } + vjps.push_back(vjp); + } + return vjps; +} + +array 
Subtract::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto jvp_fun = [&](int i) { + int arg = argnums[i]; + return arg == 1 ? negative(tangents[i], stream()) : tangents[i]; + }; + auto out = jvp_fun(0); + if (argnums.size() > 1) { + out = add(out, jvp_fun(1), stream()); + } + return out; +} + +std::pair Subtract::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + return {subtract(a, b, stream()), to_ax}; +} + +std::vector Tan::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Tan::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array cos_sq = square(cos(primals[0], stream()), stream()); + return divide(tangents[0], cos_sq, stream()); +} + +std::pair Tan::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {tan(inputs[0], stream()), axes[0]}; +} + +std::vector Tanh::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + return {jvp(primals, {cotan}, argnums)}; +} + +array Tanh::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + array cosh_sq = square(cosh(primals[0], stream()), stream()); + return divide(tangents[0], cosh_sq, stream()); +} + +std::pair Tanh::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {tanh(inputs[0], stream()), axes[0]}; +} + +std::vector Transpose::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + std::vector iaxes(axes_.size()); + for (int i = 0; i < axes_.size(); ++i) { + iaxes[axes_[i]] = i; + } + return {transpose(cotan, iaxes, stream())}; +} + +array Transpose::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(tangents.size() == 1); + return transpose(tangents[0], axes_, stream()); +} + +std::pair Transpose::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + auto vdim = axes[0]; + for (auto& dim : axes_) { + if (dim >= vdim) { + dim++; + } + } + axes_.insert(axes_.begin() + vdim, vdim); + return {transpose(inputs[0], axes_, stream()), vdim}; +} + +bool Transpose::is_equivalent(const Primitive& other) const { + const Transpose& t_other = static_cast(other); + return axes_ == t_other.axes_; +} + +} // namespace mlx::core diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp new file mode 100644 index 000000000..89e0c42a9 --- /dev/null +++ b/mlx/transforms.cpp @@ -0,0 +1,778 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/backend/metal/metal.h" +#include "mlx/ops.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" +#include "mlx/transforms.h" +#include "mlx/transforms_impl.h" +#include "mlx/utils.h" + +namespace mlx::core { + +void simplify(const std::vector& outputs) { + std::function recurse; + std::queue tape; + std::unordered_set cache; + std::unordered_map>> + parents_map; + + // Helpers to identify 
identical scalars + std::map, array> scalars; + auto is_scalar = [](const array& a) { + return a.is_evaled() && a.ndim() == 0; + }; + auto get_scalar_rep = [](const array& a) { + uint64_t v = 0; + int dtype; + switch (a.dtype().size) { + case 1: + v = *a.data(); + break; + case 4: + v = *a.data(); + break; + case 8: + v = *a.data(); + break; + } + return std::make_pair(v, a.dtype().val); + }; + + // DFS the graph to log the parents + recurse = [&](const array& a) { + auto id = a.id(); + if (cache.find(id) != cache.end()) { + return; + } + for (int i = 0; i < a.inputs().size(); i++) { + auto& in = a.inputs()[i]; + parents_map[in.id()].push_back({a, i}); + recurse(in); + } + cache.insert(id); + tape.push(a); + if (is_scalar(a)) { + scalars.insert({get_scalar_rep(a), a}); + } + }; + for (auto& a : outputs) { + recurse(a); + } + + // Helper that fuses two arrays in the graph by setting the parents of the + // source to point to the destination + auto fuse = [&](array& dst, array& src) { + auto src_parents = parents_map.find(src.id()); + if (src_parents == parents_map.end()) { + return; + } + + auto& pairs = parents_map[dst.id()]; + for (auto& parent : src_parents->second) { + parent.first.editable_inputs()[parent.second] = dst; + pairs.push_back(parent); + } + }; + + // Walk the graph + cache.clear(); + + // Depth-1 array equivalence check. + auto array_equivalent = [](const array& a, const array& b) { + if (!a.has_primitive() || !b.has_primitive()) { + return false; + } + const auto& pa = a.primitive(); + const auto& pb = b.primitive(); + if (typeid(pa) != typeid(pb)) { + return false; + } + + if (a.inputs().size() != b.inputs().size()) { + return false; + } + + for (int i = 0; i < a.inputs().size(); i++) { + if (a.inputs()[i].id() != b.inputs()[i].id()) { + return false; + } + } + + return pa.is_equivalent(pb); + }; + + while (!tape.empty()) { + auto arr = std::move(tape.front()); + tape.pop(); + + if (cache.find(arr.id()) != cache.end()) { + continue; + } + + // Check if we can fuse scalars + if (is_scalar(arr)) { + auto scalar = scalars.find(get_scalar_rep(arr)); + if (scalar->second.id() != arr.id()) { + fuse(scalar->second, arr); + arr = scalar->second; + } + } + + // Check if we can fuse the parents of this array + auto parents = parents_map.find(arr.id()); + if (parents != parents_map.end()) { + std::vector mask(parents->second.size(), false); + auto N = parents->second.size(); + for (int i = 0; i < N; i++) { + if (mask[i]) { + continue; + } + for (int j = i + 1; j < N; j++) { + if (mask[j]) { + continue; + } + auto& src = parents->second[j].first; + auto& dst = parents->second[i].first; + if (src.id() != dst.id() && array_equivalent(src, dst)) { + cache.insert(src.id()); + fuse(dst, src); + mask[j] = true; + } + } + } + } + } +} + +void eval(const std::vector& outputs, bool retain_graph /* = false */) { + if (!retain_graph) { + for (auto& out : outputs) { + if (out.has_primitive() && out.is_tracer()) { + throw std::invalid_argument( + "[eval] Illegal to eval an array during " + "function transform without graph retention."); + } + } + } + std::function recurse; + std::queue tape; + std::unordered_set cache; + std::unordered_map> deps; + + recurse = [&](const array& a) { + auto id = a.id(); + if (cache.find(id) != cache.end()) { + return; + } + for (auto in : a.inputs()) { + recurse(in); + // If one of the inputs is being computed on a different + // stream, we need to manage the dependency. 
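The comment above states the scheduling rule eval() follows: when a consumer and one of its not-yet-evaluated inputs live on different streams, a promise/shared_future pair is recorded for the input so the consumer's task can block on it before running. A loose Python analogue of that pattern built on concurrent.futures (the names and the two-stream setup are illustrative assumptions, not mlx API):

from concurrent.futures import ThreadPoolExecutor

# Hedged sketch: one worker per "stream"; producers that someone else depends on
# get a future, and consumers wait on their inputs' futures before doing work.
streams = {"cpu": ThreadPoolExecutor(max_workers=1), "gpu": ThreadPoolExecutor(max_workers=1)}
done = {}  # array name -> Future, analogous to the deps map above

def enqueue(stream, name, deps, work):
    def task():
        for d in deps:
            done[d].result()  # wait for inputs produced on another stream
        return work()
    done[name] = streams[stream].submit(task)

enqueue("gpu", "a", [], lambda: 2)                          # produced on the "gpu" stream
enqueue("cpu", "b", ["a"], lambda: done["a"].result() + 3)  # must wait for "a"
assert done["b"].result() == 5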
+ if (!in.is_evaled()) { + if (a.primitive().stream() != in.primitive().stream()) { + deps.insert({in.id(), std::shared_future{}}); + } + } + } + cache.insert(id); + if (!a.is_evaled() || (!retain_graph && a.has_primitive())) { + if (!a.has_primitive()) { + throw std::invalid_argument( + "[eval] Attempting to eval an array without a primitive."); + } + tape.push(a); + } + }; + + for (auto& arr : outputs) { + if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) { + recurse(arr); + // Insert a dependency for every output to synchronize + // with at the end. + if (!arr.is_evaled()) { + deps.insert({arr.id(), std::shared_future{}}); + } + } + } + + while (!tape.empty()) { + auto arr = std::move(tape.front()); + tape.pop(); + if (arr.is_evaled()) { + if (!retain_graph && arr.has_primitive()) { + arr.detach(); + } + continue; + } + + auto stream = arr.primitive().stream(); + std::vector> arr_deps; + for (auto& in : arr.inputs()) { + if (auto it = deps.find(in.id()); it != deps.end()) { + arr_deps.push_back(it->second); + } + } + std::shared_ptr> p{nullptr}; + if (auto it = deps.find(arr.id()); it != deps.end()) { + p = std::make_unique>(); + it->second = p->get_future().share(); + } + + if (arr.primitive().device() == Device::gpu) { + if (!metal::is_available()) { + throw std::runtime_error("Metal GPU is not available."); + } + scheduler::enqueue( + stream, + metal::make_task( + arr, std::move(arr_deps), std::move(p), retain_graph)); + } else { + auto task = [retain_graph, + arr, + stream, + arr_deps = std::move(arr_deps), + p = std::move(p)]() mutable { + for (auto& d : arr_deps) { + d.wait(); + } + scheduler::notify_new_task(stream); + arr.primitive().eval_cpu(arr.inputs(), arr); + if (!retain_graph) { + arr.detach(); + } + if (p) { + p->set_value(); + } + scheduler::notify_task_completion(stream); + }; + scheduler::enqueue(stream, std::move(task)); + } + } + for (auto& arr : outputs) { + if (auto it = deps.find(arr.id()); it != deps.end()) { + it->second.wait(); + } + } +} + +std::pair, std::vector> vjp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& cotans) { + // Make tracers from given primals + std::vector primals_; + for (auto& p : primals) { + auto s = p.has_primitive() ? p.primitive().stream() + : default_stream(default_device()); + primals_.push_back(copy(p, s)); // Does not do a deep copy + primals_.back().set_tracer(true); + } + + // Pass tracer primals through the function + // Any variables that depend on the primals are marked as tracers + auto outputs = fun(primals_); + + // Map outputs to passed cotans while ignoring the outputs + // that have stop_gradient called on them + int cotan_index = 0; + std::vector> output_cotan_pairs; + for (int i = 0; i < outputs.size(); ++i) { + auto& out = outputs[i]; + if (out.has_primitive()) { + if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) { + continue; + } + } + if (cotan_index >= cotans.size()) { + throw std::invalid_argument( + "[vjp] Number of outputs with gradient does not match number of cotangents."); + } + if (out.shape() != cotans[cotan_index].shape()) { + throw std::invalid_argument( + "[vjp] Output shape does not match shape of cotangent."); + } + output_cotan_pairs.emplace_back(i, cotan_index++); + } + + // Topologically sort the compute graph, record outputs + // in the tape if a gradient is needed. 
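The tape built below is consumed in reverse: each recorded array looks up its cotangent, asks its primitive for the VJPs of the needed arguments, and adds each result into the cotangent already accumulated for that input (or starts one). A miniature of that reverse accumulation in plain Python, with hand-written rules standing in for Primitive::vjp (illustrative only, not the mlx data structures):

# Hedged sketch of the reverse pass implemented further below: walk the tape
# backwards, pop each node's cotangent, apply its VJP rule, and accumulate
# the results into the inputs.
def reverse_pass(tape, cotan_map):
    for node, inputs, vjp_rule in reversed(tape):
        cotan = cotan_map.pop(node, None)
        if cotan is None:
            continue
        for inp, partial in zip(inputs, vjp_rule(cotan)):
            cotan_map[inp] = cotan_map.get(inp, 0.0) + partial
    return cotan_map

# f(x) = x*x + x at x = 3: the tape records y = x*x, then z = y + x.
x = 3.0
tape = [
    ("y", ["x"], lambda c: [2.0 * x * c]),  # d(x*x)/dx = 2x
    ("z", ["y", "x"], lambda c: [c, c]),    # d(y + x) passes the cotangent through
]
grads = reverse_pass(tape, {"z": 1.0})
assert grads["x"] == 2.0 * x + 1.0          # x feeds two nodes, so the cotangents add to 7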
+ std::unordered_set cache; + std::unordered_set calc_grad; + for (auto& primal : primals_) { + primal.set_tracer(false); + calc_grad.insert(primal.id()); + cache.insert(primal.id()); + } + + std::vector tape; + + std::function recurse; + recurse = [&](auto& a) { + auto id = a.id(); + a.set_tracer(false); + + // Check if visited and add to cache if not + if (auto inserted = cache.insert(id); !inserted.second) { + return; + } + + for (auto& input : a.editable_inputs()) { + recurse(input); + } + + // Stop grad + if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) { + return; + } + + // Calculate gradient if any inputs require gradient + for (auto& input : a.inputs()) { + if (calc_grad.find(input.id()) != calc_grad.end()) { + tape.push_back(a); + calc_grad.insert(id); + break; + } + } + }; + + for (auto& out : outputs) { + recurse(out); + } + + // Run the tape backwards, computing vector-jacobian + // products for each primitive + std::unordered_map cotan_map; + for (auto [out_idx, cotan_idx] : output_cotan_pairs) { + cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]}); + } + for (auto it = tape.rbegin(); it != tape.rend(); ++it) { + auto& a = *it; + + // Get the arguments whose gradients are needed + std::vector argnums; + for (int i = 0; i < a.inputs().size(); ++i) { + if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) { + argnums.push_back(i); + } + } + + auto cotan_it = cotan_map.find(a.id()); + if (cotan_it == cotan_map.end()) { + continue; + } + + auto cotan = cotan_map.extract(cotan_it).mapped(); + auto vjps = a.primitive().vjp(a.inputs(), cotan, argnums); + auto s = a.primitive().stream(); + // Accumulate the vector-jacobian products for each input + for (int i = 0; i < argnums.size(); ++i) { + auto in_id = a.inputs()[argnums[i]].id(); + if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) { + cotan_it->second = add(cotan_it->second, vjps[i], s); + } else { + cotan_map.insert({in_id, vjps[i]}); + } + } + } + + std::vector vjps; + for (auto& primal : primals_) { + if (auto cotan_it = cotan_map.find(primal.id()); + cotan_it != cotan_map.end()) { + vjps.push_back(cotan_it->second); + } else { + auto s = primal.has_primitive() ? primal.primitive().stream() + : default_stream(default_device()); + vjps.push_back(zeros_like(primal, s)); + } + } + return {outputs, vjps}; +} + +std::pair vjp( + const std::function& fun, + const array& primal, + const array& cotan) { + auto vec_fun = [fun](const std::vector& inputs) { + return std::vector{fun(inputs[0])}; + }; + auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan}); + return {outputs[0], vjps[0]}; +} + +std::pair, std::vector> jvp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& tangents) { + if (primals.size() != tangents.size()) { + throw std::invalid_argument( + "[jvp] Number of inputs does not match number of tangents."); + } + for (int i = 0; i < primals.size(); ++i) { + if (primals[i].shape() != tangents[i].shape()) { + throw std::invalid_argument( + "[jvp] Input shape does not match shape of tangent."); + } + } + + std::vector primals_; + for (auto& p : primals) { + auto s = p.has_primitive() ? p.primitive().stream() + : default_stream(default_device()); + primals_.push_back(copy(p, s)); // Does not do a deep copy + primals_.back().set_tracer(true); + } + auto outputs = fun(primals_); + + // Topologically sort the compute graph, record outputs + // in the tape if a gradient is needed. 
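Where the reverse pass above pulls cotangents from the outputs back toward the inputs, the jvp transform below pushes tangents forward through the same kind of tape: each node collects the tangents of its inputs and applies the primitive's JVP rule. A miniature of that forward propagation in plain Python (every input carries a tangent in this toy example; illustrative only):

import math

# Hedged sketch of the forward pass implemented further below.
def forward_pass(tape, tan_map):
    for node, inputs, jvp_rule in tape:
        tangents = [tan_map[i] for i in inputs if i in tan_map]
        if tangents:
            tan_map[node] = jvp_rule(tangents)
    return tan_map

# f(x) = sin(x) * x at x = 2 with tangent 1: the tape records s = sin(x), z = s * x.
x = 2.0
tape = [
    ("s", ["x"], lambda t: math.cos(x) * t[0]),                  # d sin(x) = cos(x) dx
    ("z", ["s", "x"], lambda t: t[0] * x + math.sin(x) * t[1]),  # product rule
]
tans = forward_pass(tape, {"x": 1.0})
assert abs(tans["z"] - (x * math.cos(x) + math.sin(x))) < 1e-12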
+ std::unordered_set cache; + std::unordered_set calc_grad; + for (auto& primal : primals_) { + primal.set_tracer(false); + calc_grad.insert(primal.id()); + cache.insert(primal.id()); + } + + std::vector tape; + + std::function recurse; + recurse = [&](auto& a) { + auto id = a.id(); + a.set_tracer(false); + + // Check if visited and add to cache if not + if (auto inserted = cache.insert(id); !inserted.second) { + return; + } + + for (auto& input : a.editable_inputs()) { + recurse(input); + } + + // Stop grad + if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) { + return; + } + + // Calculate gradient if any inputs require gradient + for (auto& input : a.inputs()) { + if (calc_grad.find(input.id()) != calc_grad.end()) { + tape.push_back(a); + calc_grad.insert(id); + break; + } + } + }; + + for (auto& out : outputs) { + recurse(out); + } + std::unordered_map tan_map; + for (int i = 0; i < primals_.size(); ++i) { + tan_map.insert({primals_[i].id(), tangents[i]}); + } + + for (auto& a : tape) { + // Get the arguments used in the jvp + std::vector argnums; + std::vector tangents; + for (int i = 0; i < a.inputs().size(); ++i) { + if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) { + argnums.push_back(i); + tangents.push_back(it->second); + } + } + + auto jvp = a.primitive().jvp(a.inputs(), tangents, argnums); + tan_map.insert({a.id(), jvp}); + } + + std::vector jvps; + for (auto& out : outputs) { + if (auto it = tan_map.find(out.id()); it != tan_map.end()) { + jvps.push_back(it->second); + } else { + auto s = out.has_primitive() ? out.primitive().stream() + : default_stream(default_device()); + jvps.push_back(zeros_like(out, s)); + } + } + return {outputs, jvps}; +} + +std::pair jvp( + const std::function& fun, + const array& primal, + const array& tangent) { + auto vec_fun = [fun](const std::vector& inputs) { + return std::vector{fun(inputs[0])}; + }; + auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent}); + return {outputs[0], jvps[0]}; +} + +ValueAndGradFn value_and_grad( + const std::function(const std::vector&)>& fun, + const std::vector& argnums) { + if (argnums.empty()) { + throw std::invalid_argument("[grad] Must specify at least one argument."); + } + return [fun, argnums](const std::vector& inputs) { + std::set args; + for (auto& arg : argnums) { + args.insert(arg < 0 ? arg + inputs.size() : arg); + } + if (args.size() != argnums.size()) { + throw std::invalid_argument( + "[grad] Repeat argument number not allowed in grad."); + } + if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) { + std::ostringstream msg; + msg << "[grad] Invalid argument number for function with " + << inputs.size() << " inputs."; + throw std::invalid_argument(msg.str()); + } + + auto gfun = [&fun, &inputs, &args](const std::vector& ginputs) { + std::vector inputs_(inputs); + auto argit = args.begin(); + for (int i = 0; i < ginputs.size(); ++i) { + inputs_[*argit] = ginputs[i]; + ++argit; + } + auto outputs = fun(inputs_); + for (int i = 1; i < outputs.size(); i++) { + auto& out = outputs[i]; + auto s = out.has_primitive() ? 
out.primitive().stream() + : default_stream(default_device()); + outputs[i] = stop_gradient(out, s); + } + return outputs; + }; + + std::vector ginputs; + for (auto arg : args) { + ginputs.push_back(inputs[arg]); + } + // Set the incoming gradient as int32 so that it will be promoted to the + // appropriate floating point type op(int, floatXX) -> floatXX for most ops + auto [outputs, grads] = vjp(gfun, ginputs, {array(1)}); + return std::make_pair(outputs, grads); + }; +} + +namespace detail { + +std::pair, std::vector> vmap_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs, + const std::vector& in_axes) { + if (in_axes.size() != inputs.size()) { + throw std::invalid_argument( + "[vmap] The number of in axes must match the number of inputs."); + } + + // Run the function on placeholder inputs + // to get the original graph + std::vector s_inputs; + for (int i = 0; i < inputs.size(); ++i) { + if (in_axes[i] != -1) { + if (inputs[i].ndim() == 0) { + throw std::invalid_argument( + "[vmap] Cannot vmap an input with zero dimensions."); + } + if (in_axes[i] > inputs[i].ndim()) { + std::ostringstream msg; + msg << "[vmap] Axis " << in_axes[i] << " invalid for input with " + << inputs[i].ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + std::vector shape = inputs[i].shape(); + shape.erase(shape.begin() + in_axes[i]); + array in(shape, inputs[i].dtype(), nullptr, {}); + s_inputs.push_back(in); + s_inputs.back().set_tracer(true); + } else { + s_inputs.push_back(inputs[i]); + } + } + return {s_inputs, fun(s_inputs)}; +} + +std::vector vmap_replace( + const std::vector& inputs, + const std::vector& s_inputs, + const std::vector& s_outputs, + const std::vector& in_axes, + const std::vector& out_axes) { + if (out_axes.size() != s_outputs.size()) { + throw std::invalid_argument( + "[vmap] The number of out axes must match the number of outputs."); + } + + std::unordered_map> tmap; + std::unordered_set needs_vmap; + for (int i = 0; i < s_inputs.size(); ++i) { + if (in_axes[i] != -1) { + tmap.insert({s_inputs[i].id(), {inputs[i], in_axes[i]}}); + needs_vmap.insert(s_inputs[i].id()); + } + } + + // Topologically sort the graph + std::unordered_set cache; + for (int i = 0; i < s_inputs.size(); ++i) { + auto in = s_inputs[i]; + if (in_axes[i] != -1) { + in.set_tracer(false); + } + cache.insert(in.id()); + } + std::vector tape; + + std::function recurse; + + recurse = [&](const array& a) { + // Stop at inputs to the vmap function + auto id = a.id(); + if (cache.find(id) != cache.end()) { + return; + } + for (auto& input : a.inputs()) { + recurse(input); + } + cache.insert(id); + for (auto& input : a.inputs()) { + if (needs_vmap.find(input.id()) != needs_vmap.end()) { + needs_vmap.insert(id); + tape.push_back(a); + tape.back().set_tracer(false); + break; + } + } + }; + + for (auto& out : s_outputs) { + recurse(out); + } + + // Transform each primitive in the graph with + // its vmap implementation + for (auto& a : tape) { + std::vector v_inputs; + std::vector v_axes; + for (auto& in : a.inputs()) { + auto map_it = tmap.find(in.id()); + if (map_it != tmap.end()) { + v_inputs.push_back(map_it->second.first); + v_axes.push_back(map_it->second.second); + } else { + v_inputs.push_back(in); + v_axes.push_back(-1); + } + } + auto out_and_axis = a.primitive().vmap(v_inputs, v_axes); + tmap.insert({a.id(), out_and_axis}); + } + + // Populate the outputs and make sure all the output axes are + // in the right place + std::vector outputs; + for (int i = 0; i < 
s_outputs.size(); ++i) { + auto map_it = tmap.find(s_outputs[i].id()); + if (map_it != tmap.end()) { + auto& [out, vdim] = map_it->second; + if (vdim != out_axes[i]) { + if (out_axes[i] >= out.ndim()) { + std::ostringstream msg; + msg << "[vmap] Axis " << out_axes[i] << " invalid for output with " + << out.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + std::vector reorder(out.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + reorder.erase(reorder.begin() + vdim); + reorder.insert(reorder.begin() + out_axes[i], vdim); + out = transpose(out, reorder); + } + outputs.push_back(out); + } else { + outputs.push_back(s_outputs[i]); + } + } + return outputs; +} + +} // namespace detail + +std::function(const std::vector&)> vmap( + const std::function(const std::vector&)>& fun, + const std::vector& in_axes /* = {} */, + const std::vector& out_axes /* = {} */) { + auto infer_axes = [](auto axes) { + return !axes.empty() && + std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; }); + }; + if (infer_axes(in_axes) != infer_axes(out_axes)) { + throw std::invalid_argument( + "[vmap] Input (or output) axes must be " + "specified if output (or input) axes are."); + } + auto vfun = [fun, in_axes = in_axes, out_axes = out_axes]( + const std::vector& inputs) mutable { + if (in_axes.size() == 0) { + in_axes.resize(inputs.size(), 0); + } + + auto [trace_inputs, trace_outputs] = + detail::vmap_trace(fun, inputs, in_axes); + + if (out_axes.size() == 0) { + out_axes.resize(trace_outputs.size(), 0); + } + + return detail::vmap_replace( + inputs, trace_inputs, trace_outputs, in_axes, out_axes); + }; + + return vfun; +} + +std::function vmap( + const std::function& fun, + int in_axis_a /* = 0 */, + int in_axis_b /* = 0 */, + int out_axis /* = 0 */) { + auto vfun = vmap( + [in_axis_a, in_axis_b, out_axis, fun](const std::vector& inputs) { + return std::vector{fun(inputs[0], inputs[1])}; + }, + {in_axis_a, in_axis_b}, + {out_axis}); + return [vfun](const array& a, const array& b) { return vfun({a, b})[0]; }; +} + +std::function vmap( + const std::function& fun, + int in_axis /* = 0 */, + int out_axis /* = 0 */) { + auto vfun = vmap( + [in_axis, out_axis, fun](const std::vector& inputs) { + return std::vector{fun(inputs[0])}; + }, + {in_axis}, + {out_axis}); + return [vfun](const array& a) { return vfun({a})[0]; }; +} + +} // namespace mlx::core diff --git a/mlx/transforms.h b/mlx/transforms.h new file mode 100644 index 000000000..88296521e --- /dev/null +++ b/mlx/transforms.h @@ -0,0 +1,185 @@ +#pragma once + +#include "array.h" + +namespace mlx::core { + +/** Fuse equivalent arrays to avoid duplicate execution. */ +void simplify(const std::vector& outputs); + +template +void simplify(Arrays... outputs) { + simplify(std::vector{std::forward(outputs)...}); +} + +void eval(const std::vector& outputs, bool retain_graph = false); + +template +void eval(Arrays... outputs) { + eval(std::vector{std::forward(outputs)...}, false); +} + +/** + * Computes the output and vector-Jacobian product (VJP) of a function. + * + * Computes the vector-Jacobian product of the vector of cotangents with the + * Jacobian of the function evaluated at the primals. Returns a pair of + * vectors of output arrays and VJP arrays. + **/ +std::pair, std::vector> vjp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& cotangents); + +/** + * Computes the output and vector-Jacobian product (VJP) of a unary function. 
+ */ +std::pair vjp( + const std::function& fun, + const array& primal, + const array& cotangent); + +/** + * Computes the output and Jacobian-vector product (JVP) of a function. + * + * Computes the Jacobian-vector product of the Jacobian of the function + * evaluated at the primals with the vector of tangents. Returns a pair of + * vectors of output arrays and JVP arrays. + **/ +std::pair, std::vector> jvp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& tangents); + +/** + * Computes the output and Jacobian-vector product (JVP) of a unary function. + */ +std::pair jvp( + const std::function& fun, + const array& primal, + const array& tangent); + +// Return type of general value_and_grad: a function which takes an input +// vector of arrays and returns a pair of vectors of arrays one for the +// values and one for the gradients wrt the first value. +using ValueAndGradFn = + std::function, std::vector>( + const std::vector&)>; +using SimpleValueAndGradFn = std::function>( + const std::vector&)>; + +/** + * Returns a function which computes the value and gradient of the input + * function with respect to a vector of input arrays. + **/ +ValueAndGradFn value_and_grad( + const std::function(const std::vector&)>& fun, + const std::vector& argnums); + +/** + * Returns a function which computes the value and gradient of the input + * function with repsect to a single input array. + **/ +ValueAndGradFn inline value_and_grad( + const std::function(const std::vector&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the value and gradient of the unary + * input function. + **/ +std::function(const array&)> inline value_and_grad( + const std::function& fun) { + return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function&)>& fun, + const std::vector& argnums) { + return [fun, argnums](auto inputs) { + auto result = value_and_grad( + [fun](auto inputs) { return std::vector{fun(inputs)}; }, + argnums)(inputs); + + return std::make_pair(result.first[0], result.second); + }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the gradient of the input function with + * respect to a vector of input arrays. + * + * The function being differentiated takes a vector of arrays and returns an + * array. The vector of `argnums` specifies which the arguments to compute + * the gradient with respect to. At least one argument must be specified. + **/ +std::function(const std::vector&)> inline grad( + const std::function&)>& fun, + const std::vector& argnums) { + auto fn = value_and_grad(fun, argnums); + return [fn](const std::vector& inputs) { return fn(inputs).second; }; +} + +/** + * Returns a function which computes the gradient of the input function with + * repsect to a single input array. + * + * The function being differentiated takes a vector of arrays and returns an + * array. The optional `argnum` index specifies which the argument to compute + * the gradient with respect to and defaults to 0. + **/ +std::function(const std::vector&)> inline grad( + const std::function&)>& fun, + int argnum = 0) { + return grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the gradient of the unary input function. 
+ **/ +std::function inline grad( + const std::function& fun) { + auto fn = value_and_grad(fun); + return [fn](const array& input) { return fn(input).second; }; +} + +/** + * Automatically vectorize a unary function over the requested axes. + */ +std::function vmap( + const std::function& fun, + int in_axis = 0, + int out_axis = 0); + +/** + * Automatically vectorize a binary function over the requested axes. + */ +std::function vmap( + const std::function& fun, + int in_axis_a = 0, + int in_axis_b = 0, + int out_axis = 0); + +/** + * Automatically vectorize a function over the requested axes. + * + * The input function to `vmap` takes as an argument a vector of arrays and + * returns a vector of arrays. Optionally specify the axes to vectorize over + * with `in_axes` and `out_axes`, otherwise a default of 0 is used. + * Returns a vectorized function with the same signature as the input + * function. + */ +std::function(const std::vector&)> vmap( + const std::function(const std::vector&)>& fun, + const std::vector& in_axes = {}, + const std::vector& out_axes = {}); + +} // namespace mlx::core diff --git a/mlx/types/bf16.h b/mlx/types/bf16.h new file mode 100644 index 000000000..102aea91e --- /dev/null +++ b/mlx/types/bf16.h @@ -0,0 +1,185 @@ +#pragma once + +#include +#include +#include +#include + +#define __MLX_BFLOAT_NAN__ 0x7FC0 + +namespace mlx::core { + +namespace { +union float_bits_bf16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_BFloat16 { + uint16_t bits_; + + // Default constructor + _MLX_BFloat16() = default; + + // Default copy constructor + _MLX_BFloat16(_MLX_BFloat16 const&) = default; + + // Appease std::vector for being special + _MLX_BFloat16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_BFloat16& operator=(const float& x) { + return (*this = _MLX_BFloat16(x)); + } + + // From float32 + _MLX_BFloat16(const float& x) { + if (std::isnan(x)) { + bits_ = __MLX_BFLOAT_NAN__; + } else { + // Union + float_bits_bf16 in; + + // Take bits + in.f = x; + + // Round to nearest even + in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); + + // Take upper 16 bits + bits_ = in.u >> 16; + } + } + + // To float32 + operator float() const { + // Union + float_bits_bf16 out; + + // Upper 16 bits are the data and lower 16 bits are 0s + out.u = ((uint32_t)bits_) << 16; + + return out.f; + } +}; + +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, double, double, double); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + 
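The float32-to-bfloat16 constructor above rounds to nearest even by adding 0x7FFF plus the lowest retained mantissa bit and then truncating to the top 16 bits. A stand-alone NumPy sketch of the same bit manipulation (an illustration of the trick, not the mlx type itself):

import numpy as np

# Hedged sketch of the round-to-nearest-even conversion used by _MLX_BFloat16.
def to_bfloat16_bits(x: float) -> int:
    u = int(np.array([x], dtype=np.float32).view(np.uint32)[0])
    u += ((u >> 16) & 1) + 0x7FFF      # bias for round-to-nearest-even
    return (u >> 16) & 0xFFFF          # keep the upper 16 bits

def from_bfloat16_bits(bits: int) -> float:
    return float(np.array([bits << 16], dtype=np.uint32).view(np.float32)[0])

x = 1.2345678
roundtrip = from_bfloat16_bits(to_bfloat16_bits(x))
assert abs(roundtrip - x) < 1e-2       # bfloat16 keeps roughly 3 decimal digits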
+bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +#undef bfloat_binop + +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, double, double); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop + +// Negative +inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define bfloat_inplace_op(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_op(+, operator+=); +bfloat_inplace_op(-, operator-=); +bfloat_inplace_op(*, operator*=); +bfloat_inplace_op(/, operator/=); + +#undef bfloat_inplace_op + +// Bitwise ops + +#define bfloat_bitop(__op__, __operator__) \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +bfloat_bitop(|, operator|); +bfloat_bitop(&, operator&); +bfloat_bitop(^, operator^); + +#undef bfloat_bitop + +#define bfloat_inplace_bitop(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_bitop(|, operator|=); +bfloat_inplace_bitop(&, operator&=); +bfloat_inplace_bitop(^, operator^=); + +#undef bfloat_inplace_bitop + +} // namespace mlx::core diff --git a/python/mlx/nn/__init__.py b/python/mlx/nn/__init__.py new file mode 100644 index 000000000..4991254f8 --- /dev/null +++ b/python/mlx/nn/__init__.py @@ -0,0 +1,3 @@ +from mlx.nn.layers import * +from mlx.nn import losses +from mlx.nn.utils import value_and_grad diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py new file mode 100644 index 000000000..0436441af --- /dev/null +++ b/python/mlx/nn/layers/__init__.py @@ -0,0 +1,23 @@ +from mlx.nn.layers.base import Module +from mlx.nn.layers.activations import ( + GELU, + ReLU, + SiLU, + gelu, + gelu_approx, + gelu_fast_approx, + relu, + silu, +) +from mlx.nn.layers.containers import Sequential +from mlx.nn.layers.convolution import Conv1d, Conv2d +from mlx.nn.layers.dropout import Dropout +from mlx.nn.layers.embedding import Embedding +from mlx.nn.layers.linear import Linear +from 
mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm +from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding +from mlx.nn.layers.transformer import ( + MultiHeadAttention, + TransformerEncoder, + TransformerEncoderLayer, +) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py new file mode 100644 index 000000000..ef1cde035 --- /dev/null +++ b/python/mlx/nn/layers/activations.py @@ -0,0 +1,129 @@ +import math + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +def _make_activation_module(f): + def decorator(klass): + klass.__doc__ = f.__doc__ + klass.__call__ = lambda self, x: f(x) + return klass + + return decorator + + +def relu(x): + """Applies the Rectified Linear Unit. + + Simply ``mx.maximum(x, 0)``. + """ + return mx.maximum(x, 0) + + +def silu(x): + r"""Applies the Sigmoid Linear Unit. + + Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is + the logistic sigmoid. + """ + return x * mx.sigmoid(x) + + +def gelu(x): + """Applies the Gaussian Error Linear Units function. + + .. math:: + \\textrm{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Gaussian CDF. + + See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster + approximations. + """ + return x * (1 + mx.erf(x / math.sqrt(2))) / 2 + + +def gelu_approx(x): + r"""An approximation to Gaussian Error Linear Unit. + + See :func:`gelu` for the exact computation. + + This function approximates ``gelu`` with a maximum absolute error :math:`< + 0.0003` in the range :math:`[-6, 6]` using the following + + .. math:: + + x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right) + + where :math:`\sigma(\cdot)` is the logistic sigmoid. + """ + return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square())) + + +def gelu_fast_approx(x): + r"""A fast approximation to Gaussian Error Linear Unit. + + See :func:`gelu` for the exact computation. + + This function approximates ``gelu`` with a maximum absolute error :math:`< + 0.015` in the range :math:`[-6, 6]` using the following + + .. math:: + + x = x \sigma\left(1.773 x\right) + + where :math:`\sigma(\cdot)` is the logistic sigmoid. + """ + return x * mx.sigmoid(1.773 * x) + + +@_make_activation_module(relu) +class ReLU(Module): + pass + + +@_make_activation_module(silu) +class SiLU(Module): + pass + + +class GELU(Module): + r"""Applies the Gaussian Error Linear Units. + + .. math:: + \textrm{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Gaussian CDF. + + However, if ``approx`` is set to 'precise' or 'fast' it applies + + .. math:: + \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\ + \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right) + + respectively. + + See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the + functional equivalents and information regarding error bounds. + + Args: + approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any. 
+ """ + + def __init__(self, approx="none"): + super().__init__() + + if approx == "none": + self._act = gelu + elif approx == "precise": + self._act = gelu_approx + elif approx == "fast": + self._act = gelu_fast_approx + else: + raise ValueError( + f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given" + ) + + def __call__(self, x): + return self._act(x) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py new file mode 100644 index 000000000..800242c70 --- /dev/null +++ b/python/mlx/nn/losses.py @@ -0,0 +1,6 @@ +import mlx.core as mx + + +def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1): + score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1) + return mx.logsumexp(logits, axis=axis) - score diff --git a/python/mlx/utils.py b/python/mlx/utils.py new file mode 100644 index 000000000..47842d83d --- /dev/null +++ b/python/mlx/utils.py @@ -0,0 +1,136 @@ +def tree_map(fn, tree, *rest): + """Applies ``fn`` to the leaves of the python tree ``tree`` and + returns a new collection with the results. + + If ``rest`` is provided, every item is assumed to be a superset of ``tree`` + and the corresponding leaves are provided as extra positional arguments to + ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap` + than to :func:`map`. + + .. code-block:: python + + import mlx.nn as nn + from mlx.utils import tree_map + + model = nn.Linear(10, 10) + print(model.parameters().keys()) + # dict_keys(['weight', 'bias']) + + # square the parameters + model.update(tree_map(lambda x: x*x, model.parameters())) + + Args: + fn (Callable): The function that processes the leaves of the tree + tree (Any): The main python tree that will be iterated upon + rest (Tuple[Any]): Extra trees to be iterated together with tree + + Returns: + A python tree with the new values returned by ``fn``. + """ + if isinstance(tree, list): + return [ + tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree) + ] + elif isinstance(tree, tuple): + return tuple( + tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree) + ) + elif isinstance(tree, dict): + return { + k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items() + } + else: + return fn(tree, *rest) + + +def tree_flatten(tree, prefix="", is_leaf=None): + """Flattens a python tree to a list of key, value tuples. + + The keys are using the dot notation to define trees of arbitrary depth and + complexity. + + .. code-block:: python + + from mlx.utils import tree_flatten + + print(tree_flatten([[[0]]])) + # [("0.0.0", 0)] + + print(tree_flatten([[[0]]], ".hello")) + # [("hello.0.0.0", 0)] + + .. note:: + Dictionaries should have keys that are valid python identifiers. + + Args: + tree (Any): The python tree to be flattened. + prefix (str): A prefix to use for the keys. The first character is + always discarded. + is_leaf (Callable): An optional callable that returns True if the + passed object is considered a leaf or False otherwise. + + Returns: + List[Tuple[str, Any]]: The flat representation of the python tree. 
+ """ + flat_tree = [] + + if is_leaf is None or not is_leaf(tree): + if isinstance(tree, (list, tuple)): + for i, t in enumerate(tree): + flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf)) + return flat_tree + if isinstance(tree, dict): + for k, t in tree.items(): + flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf)) + return flat_tree + + return [(prefix[1:], tree)] + + +def tree_unflatten(tree): + """Recreate a python tree from its flat representation. + + .. code-block:: python + + from mlx.utils import tree_unflatten + + d = tree_unflatten([("hello.world", 42)]) + print(d) + # {"hello": {"world": 42}} + + Args: + tree (List[Tuple[str, Any]]): The flat representation of a python tree. + For instance as returned by :meth:`tree_flatten`. + + Returns: + A python tree. + """ + if len(tree) == 1 and tree[0][0] == "": + return tree[0][1] + + try: + int(tree[0][0].split(".", maxsplit=1)[0]) + is_list = True + except ValueError: + is_list = False + + # collect children + children = {} + for key, value in tree: + current_idx, *next_idx = key.split(".", maxsplit=1) + next_idx = "" if not next_idx else next_idx[0] + if current_idx not in children: + children[current_idx] = [] + children[current_idx].append((next_idx, value)) + + # recursively map them to the original container + if is_list: + keys = sorted((int(idx), idx) for idx in children.keys()) + l = [] + for i, k in keys: + while i > len(l): + l.append({}) + l.append(tree_unflatten(children[k])) + return l + else: + return {k: tree_unflatten(v) for k, v in children.items()} diff --git a/python/src/array.cpp b/python/src/array.cpp new file mode 100644 index 000000000..5e34bfdbd --- /dev/null +++ b/python/src/array.cpp @@ -0,0 +1,1071 @@ +#include +#include +#include + +#include + +#include "python/src/indexing.h" +#include "python/src/utils.h" + +#include "mlx/ops.h" +#include "mlx/transforms.h" +#include "mlx/utils.h" + +namespace py = pybind11; +using namespace py::literals; + +enum PyScalarT { + pybool = 0, + pyint = 1, + pyfloat = 2, + pycomplex = 3, +}; + +template +py::list to_list(array& a, size_t index, int dim) { + py::list pl; + auto stride = a.strides()[dim]; + for (int i = 0; i < a.shape(dim); ++i) { + if (dim == a.ndim() - 1) { + pl.append((a.data()[index])); + } else { + pl.append(to_list(a, index, dim + 1)); + } + index += stride; + } + return pl; +} + +auto to_scalar(array& a) { + bool retain_graph = a.is_tracer(); + switch (a.dtype()) { + case bool_: + return py::cast(a.item(retain_graph)); + case uint8: + return py::cast(a.item(retain_graph)); + case uint16: + return py::cast(a.item(retain_graph)); + case uint32: + return py::cast(a.item(retain_graph)); + case uint64: + return py::cast(a.item(retain_graph)); + case int8: + return py::cast(a.item(retain_graph)); + case int16: + return py::cast(a.item(retain_graph)); + case int32: + return py::cast(a.item(retain_graph)); + case int64: + return py::cast(a.item(retain_graph)); + case float16: + return py::cast(static_cast(a.item(retain_graph))); + case float32: + return py::cast(a.item(retain_graph)); + case bfloat16: + return py::cast(static_cast(a.item(retain_graph))); + case complex64: + return py::cast(a.item>(retain_graph)); + } +} + +py::object tolist(array& a) { + if (a.ndim() == 0) { + return to_scalar(a); + } + a.eval(a.is_tracer()); + py::object pl; + switch (a.dtype()) { + case bool_: + return to_list(a, 0, 0); + case uint8: + return to_list(a, 0, 0); + case uint16: + return to_list(a, 0, 0); + case uint32: + return to_list(a, 0, 0); + case uint64: + 
return to_list(a, 0, 0); + case int8: + return to_list(a, 0, 0); + case int16: + return to_list(a, 0, 0); + case int32: + return to_list(a, 0, 0); + case int64: + return to_list(a, 0, 0); + case float16: + return to_list(a, 0, 0); + case float32: + return to_list(a, 0, 0); + case bfloat16: + return to_list(a, 0, 0); + case complex64: + return to_list>(a, 0, 0); + } +} + +template +void fill_vector(T list, std::vector& vals) { + for (auto l : list) { + if (py::isinstance(l)) { + fill_vector(l.template cast(), vals); + } else if (py::isinstance(*list.begin())) { + fill_vector(l.template cast(), vals); + } else { + vals.push_back(l.template cast()); + } + } +} + +template +PyScalarT validate_shape(T list, const std::vector& shape, int idx) { + if (idx >= shape.size()) { + throw std::invalid_argument("Initialization encountered extra dimension."); + } + auto s = shape[idx]; + if (py::len(list) != s) { + throw std::invalid_argument( + "Initialization encountered non-uniform length."); + } + + if (s == 0) { + return pyfloat; + } + + PyScalarT type = pybool; + for (auto l : list) { + PyScalarT t; + if (py::isinstance(l)) { + t = validate_shape(l.template cast(), shape, idx + 1); + } else if (py::isinstance(*list.begin())) { + t = validate_shape(l.template cast(), shape, idx + 1); + } else if (py::isinstance(l)) { + t = pybool; + } else if (py::isinstance(l)) { + t = pyint; + } else if (py::isinstance(l)) { + t = pyfloat; + } else if (PyComplex_Check(l.ptr())) { + t = pycomplex; + } else { + std::ostringstream msg; + msg << "Invalid type in array initialization" << l.get_type() << "."; + throw std::invalid_argument(msg.str()); + } + type = std::max(type, t); + } + return type; +} + +template +void get_shape(T list, std::vector& shape) { + shape.push_back(py::len(list)); + if (shape.back() > 0) { + auto& l = *list.begin(); + if (py::isinstance(l)) { + return get_shape(l.template cast(), shape); + } else if (py::isinstance(l)) { + return get_shape(l.template cast(), shape); + } + } +} + +template +array array_from_list(T pl, std::optional dtype) { + // Compute the shape + std::vector shape; + get_shape(pl, shape); + + // Validate the shape and type + auto type = validate_shape(pl, shape, 0); + + size_t size = 1; + for (auto s : shape) { + size *= s; + } + + // Make the array + switch (type) { + case pybool: { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype.value_or(bool_)); + } + case pyint: { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype.value_or(int32)); + } + case pyfloat: { + std::vector vals; + fill_vector(pl, vals); + return array(vals.begin(), shape, dtype.value_or(float32)); + } + case pycomplex: { + std::vector> vals; + fill_vector(pl, vals); + return array( + reinterpret_cast(vals.data()), + shape, + dtype.value_or(complex64)); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// MLX -> Numpy +/////////////////////////////////////////////////////////////////////////////// + +size_t elem_to_loc( + int elem, + const std::vector& shape, + const std::vector& strides) { + size_t loc = 0; + for (int i = shape.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(elem, shape[i]); + loc += q_and_r.rem * strides[i]; + elem = q_and_r.quot; + } + return loc; +} + +struct PyArrayPayload { + array a; +}; + +template +py::array_t mlx_array_to_np_t(const array& src) { + // Let py::capsule hold onto a copy of the array which holds a shared ptr to + // the data + const py::capsule 
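Annotation: ``array_from_list`` above infers a common scalar type by promoting ``bool < int < float < complex`` unless an explicit dtype is given, and ``item()``/``tolist()`` convert back to Python values. A quick sketch of that behavior from the Python side:

.. code-block:: python

    import mlx.core as mx

    # The dtype is promoted across all leaves of the nested list
    a = mx.array([[True, 2], [3, 4.5]])
    print(a.dtype)      # float32
    print(a.shape)      # [2, 2]
    print(a.tolist())   # [[1.0, 2.0], [3.0, 4.5]]

    # An explicit dtype overrides the inference
    b = mx.array([1, 2, 3], mx.float32)
    print(b.dtype)      # float32

    # Scalar arrays come back as plain Python scalars
    print(mx.array(7).item())   # 7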
freeWhenDone(new PyArrayPayload({src}), [](void* payload) { + delete reinterpret_cast(payload); + }); + // Collect strides + std::vector strides{src.strides().begin(), src.strides().end()}; + for (int i = 0; i < src.ndim(); i++) { + strides[i] *= src.itemsize(); + } + // Pack the capsule with the array + py::array_t out(src.shape(), strides, src.data(), freeWhenDone); + // Mark array as read-only + py::detail::array_proxy(out.ptr())->flags &= + ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_; + // Return array + return py::array_t(src.shape(), strides, src.data(), out); +} + +template +py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) { + // Let py::capsule hold onto a copy of the array which holds a shared ptr to + // the data + const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) { + delete reinterpret_cast(payload); + }); + // Collect strides + std::vector strides{src.strides().begin(), src.strides().end()}; + for (int i = 0; i < src.ndim(); i++) { + strides[i] *= src.itemsize(); + } + // Pack the capsule with the array + py::array out(dt, src.shape(), strides, src.data(), freeWhenDone); + // Mark array as read-only + py::detail::array_proxy(out.ptr())->flags &= + ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_; + // Return array + return py::array(dt, src.shape(), strides, src.data(), out); +} + +py::array mlx_array_to_np(const array& src) { + // Eval if not already evaled + if (!src.is_evaled()) { + eval({src}, src.is_tracer()); + } + + switch (src.dtype()) { + case bool_: + return mlx_array_to_np_t(src); + case uint8: + return mlx_array_to_np_t(src); + case uint16: + return mlx_array_to_np_t(src); + case uint32: + return mlx_array_to_np_t(src); + case uint64: + return mlx_array_to_np_t(src); + case int8: + return mlx_array_to_np_t(src); + case int16: + return mlx_array_to_np_t(src); + case int32: + return mlx_array_to_np_t(src); + case int64: + return mlx_array_to_np_t(src); + case float16: + return mlx_array_to_np_t(src, py::dtype("float16")); + case float32: + return mlx_array_to_np_t(src); + case bfloat16: { + auto a = astype(src, float32); + eval({a}, src.is_tracer()); + return mlx_array_to_np_t(a); + } + case complex64: + return mlx_array_to_np_t(src, py::dtype("complex64")); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Numpy -> MLX +/////////////////////////////////////////////////////////////////////////////// + +template +array np_array_to_mlx_contiguous( + py::array_t np_array, + const std::vector& shape, + Dtype dtype) { + // Make a copy of the numpy buffer + // Get buffer ptr pass to array constructor + py::buffer_info buf = np_array.request(); + const T* data_ptr = static_cast(buf.ptr); + return array(data_ptr, shape, dtype); + + // Note: Leaving the following memoryless copy from np to mx commented + // out for the time being since it is unsafe given that the incoming + // numpy array may change the underlying data + + // // Share underlying numpy buffer + // // Copy to increase ref count + // auto deleter = [np_array](void*) {}; + // void* data_ptr = np_array.mutable_data(); + // // Use buffer from numpy + // return array(data_ptr, deleter, shape, dtype); +} + +template <> +array np_array_to_mlx_contiguous( + py::array_t, py::array::c_style | py::array::forcecast> + np_array, + const std::vector& shape, + Dtype dtype) { + // Get buffer ptr pass to array constructor + py::buffer_info buf = np_array.request(); + auto data_ptr = static_cast*>(buf.ptr); + return array(reinterpret_cast(data_ptr), 
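Annotation: ``mlx_array_to_np`` backs ``__array__``: it keeps the mlx array alive through a capsule, marks the NumPy result read-only, and (since there is no NumPy bfloat16 here) upcasts bfloat16 through float32. A sketch of what a caller sees, assuming ``np.asarray`` reuses the ``__array__`` result without copying:

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    a = mx.array([[1.0, 2.0], [3.0, 4.0]])

    na = np.asarray(a)          # goes through __array__ on the evaluated array
    print(na.dtype, na.shape)   # float32 (2, 2)
    print(na.flags.writeable)   # False: the exposed buffer is read-only

    # bfloat16 has no native NumPy dtype in this binding, so it is upcast
    b = mx.array([1.0, 2.0]).astype(mx.bfloat16)
    print(np.asarray(b).dtype)  # float32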
shape, dtype); +} + +array np_array_to_mlx(py::array np_array, std::optional dtype) { + // Compute the shape and size + std::vector shape; + for (int i = 0; i < np_array.ndim(); i++) { + shape.push_back(np_array.shape(i)); + } + + // Get dtype + auto type = np_array.dtype(); + + // Copy data and make array + if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(int32)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(uint32)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(bool_)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(float32)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(float32)); + } else if (type.is(py::dtype("float16"))) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(float16)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(uint8)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(uint16)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(uint64)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(int8)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(int16)); + } else if (type.is(py::dtype::of())) { + return np_array_to_mlx_contiguous( + np_array, shape, dtype.value_or(int64)); + } else if (type.is(py::dtype::of>())) { + return np_array_to_mlx_contiguous>( + np_array, shape, dtype.value_or(complex64)); + } else if (type.is(py::dtype::of>())) { + return np_array_to_mlx_contiguous>( + np_array, shape, dtype.value_or(complex64)); + } else { + std::ostringstream msg; + msg << "Cannot convert numpy array of type " << type << " to mlx array."; + throw std::invalid_argument(msg.str()); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Module +/////////////////////////////////////////////////////////////////////////////// + +void init_array(py::module_& m) { + // Types + py::class_( + m, + "Dtype", + R"pbdoc( + An object to hold the type of a :class:`array`. + + See the :ref:`list of types ` for more details + on available data types. 
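Annotation: ``np_array_to_mlx`` copies the NumPy buffer and maps its dtype to an mlx dtype; notably double-precision input is narrowed to ``float32`` by default, and an explicit dtype argument overrides the mapping. A short sketch:

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    x64 = np.linspace(0.0, 1.0, 5)      # float64 on the NumPy side
    a = mx.array(x64)
    print(a.dtype)                      # float32: doubles are narrowed by default

    # An explicit dtype overrides the default mapping
    b = mx.array(np.arange(4), mx.float32)
    print(b.dtype)                      # float32

    # The buffer is copied, so later NumPy writes do not affect the mlx array
    x64[:] = 0.0
    print(a)                            # still the original values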
+ )pbdoc") + .def_readonly( + "size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") + .def( + "__repr__", + [](const Dtype& t) { + std::ostringstream os; + os << t; + return os.str(); + }) + .def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; }); + m.attr("bool_") = py::cast(bool_); + m.attr("uint8") = py::cast(uint8); + m.attr("uint16") = py::cast(uint16); + m.attr("uint32") = py::cast(uint32); + m.attr("uint64") = py::cast(uint64); + m.attr("int8") = py::cast(int8); + m.attr("int16") = py::cast(int16); + m.attr("int32") = py::cast(int32); + m.attr("int64") = py::cast(int64); + m.attr("float16") = py::cast(float16); + m.attr("float32") = py::cast(float32); + m.attr("bfloat16") = py::cast(bfloat16); + m.attr("complex64") = py::cast(complex64); + + py::class_(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc") + .def( + py::init([](ScalarOrArray v, std::optional t) { + auto arr = to_array(v, t); + return astype(arr, t.value_or(arr.dtype())); + }), + "val"_a, + "dtype"_a = std::nullopt) + .def( + py::init([](std::variant pl, + std::optional dtype) { + if (auto pv = std::get_if(&pl); pv) { + return array_from_list(*pv, dtype); + } else { + auto v = std::get(pl); + return array_from_list(v, dtype); + } + }), + "vals"_a, + "dtype"_a = std::nullopt) + .def( + py::init([](py::array np_array, std::optional dtype) { + return np_array_to_mlx(np_array, dtype); + }), + "vals"_a, + "dtype"_a = std::nullopt) + .def( + py::init([](py::buffer np_buffer, std::optional dtype) { + return np_array_to_mlx(np_buffer, dtype); + }), + "vals"_a, + "dtype"_a = std::nullopt) + .def_property_readonly( + "size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc") + .def_property_readonly( + "ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc") + // TODO, this makes a deep copy of the shape + // implement alternatives to use reference + // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + .def_property_readonly( + "shape", + [](const array& a) { return a.shape(); }, + R"pbdoc( + The shape of the array as a Python list. + + Returns: + list(int): A list containing the sizes of each dimension. + )pbdoc") + .def_property_readonly( + "dtype", + &array::dtype, + R"pbdoc( + The array's :class:`Dtype`. + )pbdoc") + .def( + "item", + &to_scalar, + R"pbdoc( + Access the value of a scalar array. + + Returns: + Standard Python scalar. + )pbdoc") + .def( + "tolist", + &tolist, + R"pbdoc( + Convert the array to a Python :class:`list`. + + Returns: + list: The Python list. + + If the array is a scalar then a standard Python scalar is returned. + + If the array has more than one dimension then the result is a nested + list of lists. + + The value type of the list correpsonding to the last dimension is either + ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array. + )pbdoc") + .def("__array__", &mlx_array_to_np) + .def( + "astype", + &astype, + "dtype"_a, + "stream"_a = none, + R"pbdoc( + Cast the array to a specified type. + + Args: + dtype (Dtype): Type to which the array is cast. + stream (Stream): Stream (or device) for the operation. + + Returns: + array: The array with type ``dtype``. 
+ )pbdoc") + .def("__getitem__", mlx_get_item) + .def("__setitem__", mlx_set_item) + .def( + "__len__", + [](const array& a) { + if (a.ndim() == 0) { + throw py::type_error("len() 0-dimensional array."); + } + return a.shape(0); + }) + .def( + "__iter__", + [](const array& a) { return py::make_iterator(a); }, + py::keep_alive<0, 1>()) + .def( + "__add__", + [](const array& a, const ScalarOrArray v) { + return add(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__radd__", + [](const array& a, const ScalarOrArray v) { + return add(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__sub__", + [](const array& a, const ScalarOrArray v) { + return subtract(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__rsub__", + [](const array& a, const ScalarOrArray v) { + return subtract(to_array(v, a.dtype()), a); + }, + "other"_a) + .def( + "__mul__", + [](const array& a, const ScalarOrArray v) { + return multiply(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__rmul__", + [](const array& a, const ScalarOrArray v) { + return multiply(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__truediv__", + [](const array& a, const ScalarOrArray v) { + return divide(a, to_array(v, float32)); + }, + "other"_a) + .def( + "__div__", + [](const array& a, const ScalarOrArray v) { + return divide(a, to_array(v, float32)); + }, + "other"_a) + .def( + "__rtruediv__", + [](const array& a, const ScalarOrArray v) { + return divide(to_array(v, float32), a); + }, + "other"_a) + .def( + "__rdiv__", + [](const array& a, const ScalarOrArray v) { + return divide(to_array(v, float32), a); + }, + "other"_a) + .def( + "__eq__", + [](const array& a, const ScalarOrArray v) { + return equal(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__lt__", + [](const array& a, const ScalarOrArray v) { + return less(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__le__", + [](const array& a, const ScalarOrArray v) { + return less_equal(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__gt__", + [](const array& a, const ScalarOrArray v) { + return greater(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__ge__", + [](const array& a, const ScalarOrArray v) { + return greater_equal(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "__ne__", + [](const array& a, const ScalarOrArray v) { + return not_equal(a, to_array(v, a.dtype())); + }, + "other"_a) + .def("__neg__", [](const array& a) { return -a; }) + .def("__bool__", [](array& a) { return py::bool_(to_scalar(a)); }) + .def( + "__repr__", + [](array& a) { + if (!a.is_evaled()) { + a.eval(a.is_tracer()); + } + std::ostringstream os; + os << a; + return os.str(); + }) + .def( + "__matmul__", [](array& a, array& other) { return matmul(a, other); }) + .def( + "__pow__", + [](const array& a, const ScalarOrArray v) { + return power(a, to_array(v, a.dtype())); + }, + "other"_a) + .def( + "reshape", + [](const array& a, py::args shape, StreamOrDevice s) { + if (shape.size() == 1) { + py::object arg = shape[0]; + if (!py::isinstance(arg)) { + return reshape(a, py::cast>(arg), s); + } + } + return reshape(a, py::cast>(shape), s); + }, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Equivalent to :func:`reshape` but the shape can be passed either as a + tuple or as separate arguments. + + See :func:`reshape` for full documentation. 
+ )pbdoc") + .def( + "squeeze", + [](const array& a, const IntOrVec& v, const StreamOrDevice& s) { + if (std::holds_alternative(v)) { + return squeeze(a, s); + } else if (auto pv = std::get_if(&v); pv) { + return squeeze(a, *pv, s); + } else { + return squeeze(a, std::get>(v), s); + } + }, + "axis"_a = none, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + See :func:`squeeze`. + )pbdoc") + .def( + "abs", + &mlx::core::abs, + py::kw_only(), + "stream"_a = none, + "See :func:`abs`.") + .def( + "square", + &square, + py::kw_only(), + "stream"_a = none, + "See :func:`square`.") + .def( + "sqrt", + &mlx::core::sqrt, + py::kw_only(), + "stream"_a = none, + "See :func:`sqrt`.") + .def( + "rsqrt", + &rsqrt, + py::kw_only(), + "stream"_a = none, + "See :func:`rsqrt`.") + .def( + "reciprocal", + &reciprocal, + py::kw_only(), + "stream"_a = none, + "See :func:`reciprocal`.") + .def( + "exp", + &mlx::core::exp, + py::kw_only(), + "stream"_a = none, + "See :func:`exp`.") + .def( + "log", + &mlx::core::log, + py::kw_only(), + "stream"_a = none, + "See :func:`log`.") + .def( + "log2", + &mlx::core::log2, + py::kw_only(), + "stream"_a = none, + "See :func:`log2`.") + .def( + "log10", + &mlx::core::log10, + py::kw_only(), + "stream"_a = none, + "See :func:`log10`.") + .def( + "sin", + &mlx::core::sin, + py::kw_only(), + "stream"_a = none, + "See :func:`sin`.") + .def( + "cos", + &mlx::core::cos, + py::kw_only(), + "stream"_a = none, + "See :func:`cos`.") + .def( + "log1p", + &mlx::core::log1p, + py::kw_only(), + "stream"_a = none, + "See :func:`log1p`.") + .def( + "all", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`all`.") + .def( + "any", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`any`.") + .def( + "transpose", + [](const array& a, py::args axes, StreamOrDevice s) { + if (axes.size() > 0) { + if (axes.size() == 1) { + py::object arg = axes[0]; + if (!py::isinstance(arg)) { + return transpose(a, py::cast>(arg), s); + } + } + return transpose(a, py::cast>(axes), s); + } else { + return transpose(a, s); + } + }, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Equivalent to :func:`transpose` but the axes can be passed either as + a tuple or as separate arguments. + + See :func:`transpose` for full documentation. 
+ )pbdoc") + .def_property_readonly( + "T", + [](const array& a) { return transpose(a); }, + "Equivalent to calling ``self.transpose()`` with no arguments.") + .def( + "sum", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`sum`.") + .def( + "prod", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`prod`.") + .def( + "min", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`min`.") + .def( + "max", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`max`.") + .def( + "logsumexp", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`logsumexp`.") + .def( + "mean", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`mean`.") + .def( + "var", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + int ddof, + StreamOrDevice s) { + return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); + }, + "axis"_a = none, + "keepdims"_a = false, + "ddof"_a = 0, + py::kw_only(), + "stream"_a = none, + "See :func:`var`.") + .def( + "split", + [](const array& a, + const std::variant>& indices_or_sections, + int axis, + StreamOrDevice s) { + if (auto pv = std::get_if(&indices_or_sections); pv) { + return split(a, *pv, axis, s); + } else { + return split( + a, std::get>(indices_or_sections), axis, s); + } + }, + "indices_or_sections"_a, + "axis"_a = 0, + py::kw_only(), + "stream"_a = none, + "See :func:`split`.") + .def( + "argmin", + [](const array& a, + std::optional axis, + bool keepdims, + StreamOrDevice s) { + if (axis) { + return argmin(a, *axis, keepdims, s); + } else { + return argmin(a, keepdims, s); + } + }, + "axis"_a = std::nullopt, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`argmin`.") + .def( + "argmax", + [](const array& a, + std::optional axis, + bool keepdims, + StreamOrDevice s) { + if (axis) { + return argmax(a, *axis, keepdims, s); + } else { + return argmax(a, keepdims, s); + } + }, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + "See :func:`argmax`.") + .def( + "cumsum", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cumsum(a, *axis, reverse, inclusive, s); + } else { + // TODO: Implement that in the C++ API as well. See concatenate + // above. 
+ return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "axis"_a = none, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + "See :func:`cumsum`.") + .def( + "cumprod", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cumprod(a, *axis, reverse, inclusive, s); + } else { + // TODO: Implement that in the C++ API as well. See concatenate + // above. + return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "axis"_a = none, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + "See :func:`cumprod`.") + .def( + "cummax", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cummax(a, *axis, reverse, inclusive, s); + } else { + // TODO: Implement that in the C++ API as well. See concatenate + // above. + return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "axis"_a = none, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + "See :func:`cummax`.") + .def( + "cummin", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cummin(a, *axis, reverse, inclusive, s); + } else { + // TODO: Implement that in the C++ API as well. See concatenate + // above. + return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "axis"_a = none, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + "See :func:`cummin`."); +} diff --git a/python/src/device.cpp b/python/src/device.cpp new file mode 100644 index 000000000..45ae44aec --- /dev/null +++ b/python/src/device.cpp @@ -0,0 +1,42 @@ +#include + +#include + +#include "mlx/device.h" +#include "mlx/utils.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +void init_device(py::module_& m) { + py::enum_(m, "DeviceType") + .value("cpu", Device::DeviceType::cpu) + .value("gpu", Device::DeviceType::gpu) + .export_values() + .def( + "__eq__", + [](const Device::DeviceType& d1, const Device& d2) { + return d1 == d2; + }, + py::prepend()); + + py::class_(m, "Device") + .def(py::init(), "type"_a, "index"_a = 0) + .def_readonly("type", &Device::type) + .def( + "__repr__", + [](const Device& d) { + std::ostringstream os; + os << d; + return os.str(); + }) + .def("__eq__", [](const Device& d1, const Device& d2) { + return d1 == d2; + }); + + py::implicitly_convertible(); + + m.def("default_device", &default_device); + m.def("set_default_device", &set_default_device, "device"_a); +} diff --git a/python/src/metal.cpp b/python/src/metal.cpp new file mode 100644 index 000000000..75e0437bf --- /dev/null +++ b/python/src/metal.cpp @@ -0,0 +1,12 @@ +#include + +#include "mlx/backend/metal/metal.h" + +namespace py = pybind11; + +using namespace mlx::core; + +void init_metal(py::module_& m) { + py::module_ metal = m.def_submodule("metal", "mlx.metal"); + metal.def("is_available", &metal::is_available); +} diff --git a/python/src/stream.cpp b/python/src/stream.cpp new file mode 100644 index 000000000..f7a73463f --- /dev/null +++ b/python/src/stream.cpp @@ -0,0 +1,32 @@ +#include + +#include + +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +void init_stream(py::module_& m) { + py::class_(m, "Stream") + .def(py::init(), "index"_a, "device"_a) + 
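Annotation: the cumulative methods default to ``axis=None``, in which case (per the TODO above) the array is flattened before scanning; ``reverse`` and ``inclusive`` are keyword-only. A quick sketch:

.. code-block:: python

    import mlx.core as mx

    x = mx.array([[1, 2, 3], [4, 5, 6]])

    print(x.cumsum(axis=1))                 # scan along each row
    print(x.cumsum())                       # axis omitted: flattened to 1-D first
    print(x.cumsum(axis=1, reverse=True))   # right-to-left scan
    print(x.cummax(axis=0))                 # running maximum down the columns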
.def_readonly("device", &Stream::device) + .def( + "__repr__", + [](const Stream& s) { + std::ostringstream os; + os << s; + return os.str(); + }) + .def("__eq__", [](const Stream& s1, const Stream& s2) { + return s1 == s2; + }); + + py::implicitly_convertible(); + + m.def("default_stream", &default_stream, "device"_a); + m.def("set_default_stream", &set_default_stream, "stream"_a); + m.def("new_stream", &new_stream, "device"_a); +} diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py new file mode 100644 index 000000000..e73937d79 --- /dev/null +++ b/python/tests/test_bf16.py @@ -0,0 +1,188 @@ +import unittest +from itertools import permutations + +import math +import mlx.core as mx +import numpy as np + +import mlx_tests + +try: + import torch + + has_torch = True +except ImportError as e: + has_torch = False + + +class TestBF16(mlx_tests.MLXTestCase): + def __test_ops( + self, + ref_op, # Function that outputs array_like + mlx_op, # Function that outputs array_like + np_args, # Numpy arguments + ref_transform=lambda x: x, + mlx_transform=lambda x: mx.array(x), + atol=1e-5, + ): + ref_args = map(ref_transform, np_args) + mlx_args = map(mlx_transform, np_args) + + r_ref = ref_op(*ref_args) + r_mlx = mlx_op(*mlx_args) + + self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol)) + + def __default_test( + self, + op, + np_args, + simple_transform=lambda x: x, + atol_np=1e-3, + atol_torch=1e-5, + np_kwargs=dict(), + mlx_kwargs=dict(), + torch_kwargs=dict(), + torch_op=None, + ): + with self.subTest(reference="numpy"): + + def np_transform(x): + x_mx_bf16 = mx.array(x).astype(mx.bfloat16) + x_mx_fp32 = x_mx_bf16.astype(mx.float32) + return np.asarray(x_mx_fp32) + + def mlx_fn(*args): + out_bf16 = getattr(mx, op)(*args, **mlx_kwargs) + return np.asarray(out_bf16.astype(mx.float32)) + + def np_fn(*args): + out_fp32 = getattr(np, op)(*args, **np_kwargs) + return np_transform(out_fp32) + + ref_op = np_fn + mlx_op = mlx_fn + + ref_transform = lambda x: simple_transform(np_transform(x)) + mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16)) + + self.__test_ops( + ref_op, + mlx_op, + np_args, + ref_transform=ref_transform, + mlx_transform=mlx_transform, + atol=atol_np, + ) + + if has_torch: + with self.subTest(reference="torch"): + torch_op = op if torch_op is None else torch_op + + def torch_fn(*args): + out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs) + return out_bf16.to(torch.float32).numpy() + + ref_op = torch_fn + ref_transform = lambda x: simple_transform( + torch.from_numpy(x).to(torch.bfloat16) + ) + self.__test_ops( + ref_op, + mlx_op, + np_args, + ref_transform=ref_transform, + mlx_transform=mlx_transform, + atol=atol_torch, + ) + + def test_unary_ops(self): + x = np.random.rand(18, 28, 38) + for op in ["abs", "exp", "log", "square", "sqrt"]: + with self.subTest(op=op): + np_args = (x.astype(np.float32),) + self.__default_test(op, np_args) + + def test_binary_ops(self): + x = np.random.rand(18, 28, 38) + y = np.random.rand(18, 28, 38) + for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]: + with self.subTest(op=op): + np_args = ( + x.astype(np.float32), + y.astype(np.float32), + ) + self.__default_test(op, np_args, simple_transform=lambda x: x) + self.__default_test(op, np_args, simple_transform=lambda x: x[:1]) + self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1]) + + def test_reduction_ops(self): + x = np.random.rand(18, 28, 38).astype(np.float32) + + for op in ("min", "max"): + with self.subTest(op=op): 
+ + for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): + with self.subTest(axes=axes): + np_args = (x.astype(np.float32),) + self.__default_test( + op, + np_args, + np_kwargs={"axis": axes}, + mlx_kwargs={"axis": axes}, + torch_kwargs={"dim": axes}, + torch_op="a" + op, + ) + + def test_arg_reduction_ops(self): + data = np.random.rand(10, 12, 13).astype(np.float32) + x = mx.array(data).astype(mx.bfloat16) + data = np.asarray(x.astype(mx.float32)) + + for op in ["argmin", "argmax"]: + for axis in range(3): + for kd in [True, False]: + a = getattr(mx, op)(x, axis, kd) + b = getattr(np, op)(data, axis, keepdims=kd) + a = a.astype(mx.float32) + self.assertEqual(a.tolist(), b.tolist()) + + for op in ["argmin", "argmax"]: + a = getattr(mx, op)(x, keepdims=True) + b = getattr(np, op)(data, keepdims=True) + a = a.astype(mx.float32) + self.assertEqual(a.tolist(), b.tolist()) + a = getattr(mx, op)(x) + b = getattr(np, op)(data) + a = a.astype(mx.float32) + self.assertEqual(a.item(), b) + + def test_blas_ops(self): + if mx.default_device() != mx.gpu: + return + + def test_blas(shape_x, shape_y): + np.random.seed(42) + with self.subTest(shape_x=shape_x, shape_y=shape_y): + x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x) + y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y) + + np_args = ( + x.astype(np.float32), + y.astype(np.float32), + ) + op = "matmul" + + self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3) + + for shape_x, shape_y in [ + [(32, 32), (32, 32)], + [(23, 57), (57, 1)], + [(1, 3), (3, 128)], + [(8, 128, 768), (768, 16)], + ]: + test_blas(shape_x, shape_y) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/creations_tests.cpp b/tests/creations_tests.cpp new file mode 100644 index 000000000..8d03ff1f9 --- /dev/null +++ b/tests/creations_tests.cpp @@ -0,0 +1,224 @@ +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test arange") { + // Check type is inferred correclty + { + auto x = arange(10); + CHECK_EQ(x.dtype(), int32); + + x = arange(10.0); + CHECK_EQ(x.dtype(), float32); + + x = arange(10, float32); + CHECK_EQ(x.dtype(), float32); + + x = arange(10.0, int32); + CHECK_EQ(x.dtype(), int32); + + x = arange(0, 10); + CHECK_EQ(x.dtype(), int32); + + x = arange(0.0, 10.0, int32); + CHECK_EQ(x.dtype(), int32); + + x = arange(0.0, 10.0); + CHECK_EQ(x.dtype(), float32); + + x = arange(0, 10, float32); + CHECK_EQ(x.dtype(), float32); + + x = arange(0, 10, 0.1, float32); + CHECK_EQ(x.dtype(), float32); + + x = arange(0.0, 10.0, 0.5, int32); + CHECK_EQ(x.dtype(), int32); + + x = arange(10.0, uint32); + CHECK_EQ(x.dtype(), uint32); + x = arange(0.0, 10.0, uint32); + CHECK_EQ(x.dtype(), uint32); + x = arange(0.0, 10.0, 0.5, uint32); + CHECK_EQ(x.dtype(), uint32); + + // arange unsupported for bool_ + CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument); + } + + // Check correct sizes + { + auto x = arange(10); + CHECK_EQ(x.size(), 10); + + x = arange(0.0, 10.0, 0.5); + CHECK_EQ(x.size(), 20); + + x = arange(0.0, 10.0, 0.45); + CHECK_EQ(x.size(), 23); + + x = arange(0, 10, 10); + CHECK_EQ(x.size(), 1); + + x = arange(0, 10, 9); + CHECK_EQ(x.size(), 2); + + x = arange(0, 10, 100); + CHECK_EQ(x.size(), 1); + + x = arange(0, -10, 1); + CHECK_EQ(x.size(), 0); + + x = arange(0, -10, -1); + CHECK_EQ(x.size(), 10); + + x = arange(0, -10, -10); + CHECK_EQ(x.size(), 1); + } + + // Check values + { + auto x = arange(0, 3); + CHECK(array_equal(x, array({0, 1, 2})).item()); + + x = arange(0, 3, 2); + 
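Annotation: because bfloat16 has no NumPy dtype in this binding, the tests round-trip both sides through float32 before comparing; the same pattern is handy when checking bfloat16 results by hand. A short sketch:

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    x = np.random.rand(4, 8).astype(np.float32)

    # Quantize the reference the same way: float32 -> bfloat16 -> float32
    x_bf16 = mx.array(x).astype(mx.bfloat16)
    x_ref = np.asarray(x_bf16.astype(mx.float32))

    out = x_bf16.sum(axis=1).astype(mx.float32)
    ref = x_ref.sum(axis=1)

    # Loose tolerance to allow for bfloat16 rounding
    print(np.allclose(np.asarray(out), ref, atol=1e-2))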
CHECK(array_equal(x, array({0, 2})).item()); + + x = arange(0, 3, 3); + CHECK(array_equal(x, array({0})).item()); + + x = arange(0, -3, 1); + CHECK(array_equal(x, array({})).item()); + + x = arange(0, 3, -1); + CHECK(array_equal(x, array({})).item()); + + x = arange(0, -3, -1); + CHECK(array_equal(x, array({0, -1, -2})).item()); + + x = arange(0.0, 5.0, 0.5, int32); + CHECK(array_equal(x, zeros({10})).item()); + + x = arange(0.0, 5.0, 1.5, int32); + CHECK(array_equal(x, array({0, 1, 2, 3})).item()); + } +} + +TEST_CASE("test astype") { + // Check type conversions + { + auto x = array(1); + auto y = astype(x, float32); + CHECK_EQ(y.dtype(), float32); + CHECK_EQ(y.item(), 1.0f); + + y = astype(x, int32); + CHECK_EQ(y.dtype(), int32); + CHECK_EQ(y.item(), 1); + + x = array(-3.0f); + y = astype(x, int32); + CHECK_EQ(y.dtype(), int32); + CHECK_EQ(y.item(), -3); + + y = astype(x, uint32); + CHECK_EQ(y.dtype(), uint32); + + // Use std::copy since the result is platform dependent + uint32_t v; + std::copy(x.data(), x.data() + 1, &v); + CHECK_EQ(y.item(), v); + } +} + +TEST_CASE("test full") { + // Check full works for different types + { + auto x = full({}, 0); + CHECK_EQ(x.dtype(), int32); + CHECK_EQ(x.item(), 0); + + x = full({}, 0.0); + CHECK_EQ(x.dtype(), float32); + CHECK_EQ(x.item(), 0); + + x = full({}, false); + CHECK_EQ(x.item(), false); + + x = full({}, 0, int32); + CHECK_EQ(x.item(), 0); + + x = full({}, 0, float32); + CHECK_EQ(x.item(), 0); + + x = full({1, 2}, 2, float32); + CHECK(array_equal(x, array({2.0, 2.0}, {1, 2})).item()); + + x = full({2, 1}, 2, float32); + CHECK(array_equal(x, array({2.0, 2.0}, {2, 1})).item()); + + x = full({2}, false); + CHECK_EQ(x.dtype(), bool_); + CHECK(array_equal(x, array({false, false})).item()); + + x = full({2}, 1.0, bool_); + CHECK_EQ(x.dtype(), bool_); + CHECK(array_equal(x, array({true, true})).item()); + + x = full({2}, 1.0, uint32); + CHECK_EQ(x.dtype(), uint32); + CHECK(array_equal(x, array({1, 1})).item()); + + CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument); + } + + // Check broadcasting works + { + auto x = full({2, 2}, array({3, 4}, {2, 1})); + CHECK(array_equal(x, array({3, 3, 4, 4}, {2, 2})).item()); + x = full({2, 2}, array({3, 4}, {1, 2})); + CHECK(array_equal(x, array({3, 4, 3, 4}, {2, 2})).item()); + } + + // Check zeros and ones + { + auto x = zeros({2, 2}, float32); + CHECK_EQ(x.shape(), std::vector{2, 2}); + CHECK_EQ(x.ndim(), 2); + CHECK_EQ(x.dtype(), float32); + auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2}); + CHECK(array_equal(x, y).item()); + + x = ones({2, 2}, float32); + CHECK_EQ(x.shape(), std::vector{2, 2}); + CHECK_EQ(x.ndim(), 2); + CHECK_EQ(x.dtype(), float32); + y = array({1.0, 1.0, 1.0, 1.0}, {2, 2}); + CHECK(array_equal(x, y).item()); + + x = zeros({2, 2}, int32); + y = zeros_like(x); + CHECK_EQ(y.dtype(), int32); + CHECK(array_equal(x, y).item()); + + x = ones({2, 2}, int32); + y = ones_like(x); + CHECK_EQ(y.dtype(), int32); + CHECK(array_equal(x, y).item()); + } + + // Works for empty shape and empty array + { + array x = ones({}, int32); + CHECK_EQ(x.shape(), std::vector{}); + CHECK_EQ(x.item(), 1); + + x = full({0}, array({})); + CHECK_EQ(x.shape(), std::vector{0}); + CHECK_EQ(x.size(), 0); + + CHECK_THROWS_AS(full({}, array({})), std::invalid_argument); + } +} diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp new file mode 100644 index 000000000..c6940751b --- /dev/null +++ b/tests/fft_tests.cpp @@ -0,0 +1,331 @@ +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace 
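Annotation: the arange tests above pin down the dtype-inference rules (integer arguments give ``int32``, any float argument gives ``float32``, an explicit dtype wins) and the size formula ``ceil((stop - start) / step)``. Assuming the Python binding (not part of this file) mirrors the C++ op, a quick check looks like:

.. code-block:: python

    import mlx.core as mx

    print(mx.arange(10).dtype)          # int32
    print(mx.arange(10.0).dtype)        # float32
    print(mx.arange(0, 10, 0.5).size)   # 20
    print(mx.arange(0.0, 10.0, dtype=mx.int32).dtype)   # explicit dtype wins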
mlx::core; + +TEST_CASE("test fft basics") { + auto device = default_device(); + set_default_device(Device::cpu); + array x(1.0); + CHECK_THROWS(fft::fft(x)); + CHECK_THROWS(fft::ifft(x)); + + x = array({1.0}); + auto y = fft::fft(x); + CHECK_EQ(y.dtype(), complex64); + CHECK_EQ(y.size(), x.size()); + CHECK_EQ(y.item(), complex64_t{1.0f, 0.0f}); + + y = fft::ifft(x); + CHECK_EQ(y.dtype(), complex64); + CHECK_EQ(y.size(), x.size()); + CHECK_EQ(y.item(), complex64_t{1.0f, 0.0f}); + + x = array({complex64_t{1.0f, 1.0f}}, complex64); + y = fft::fft(x); + CHECK_EQ(y.size(), x.size()); + CHECK_EQ(y.item(), complex64_t{1.0f, 1.0f}); + + y = fft::ifft(x); + CHECK_EQ(y.dtype(), complex64); + CHECK_EQ(y.size(), x.size()); + CHECK_EQ(y.item(), complex64_t{1.0f, 1.0f}); + + { + x = array({0.0f, 1.0f, 2.0f, 3.0f}); + y = fft::fft(x); + std::initializer_list expected = { + {6.0, 0.0}, + {-2.0, 2.0}, + {-2.0, 0.0}, + {-2.0, -2.0}, + }; + CHECK_EQ(y.size(), x.size()); + CHECK(array_equal(y, array(expected)).item()); + + y = fft::ifft(x); + std::initializer_list expected_inv = { + {1.5, 0.0}, + {-0.5, -0.5}, + {-0.5, 0.0}, + {-0.5, 0.5}, + }; + CHECK(array_equal(y, array(expected_inv)).item()); + } + + { + std::initializer_list vals = { + {1.0f, 1.0f}, {2.0f, 1.0f}, {1.0f, 2.0f}, {2.0f, 2.0f}}; + x = array(vals); + y = fft::fft(x); + std::initializer_list expected = { + {6.0, 6.0}, + {-1.0, -1.0}, + {-2.0, 0.0}, + {1.0, -1.0}, + }; + CHECK_EQ(y.size(), x.size()); + CHECK(array_equal(y, array(expected)).item()); + CHECK(array_equal(fft::ifft(y), x).item()); + } + + // Specify axes + { + x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}); + std::initializer_list expected_0 = { + {2.0, 0.0}, + {4.0, 0.0}, + {-2.0, 0.0}, + {-2.0, 0.0}, + }; + y = fft::fft(x, 0); + CHECK(array_equal(y, array(expected_0, {2, 2})).item()); + CHECK(array_equal(fft::ifft(y, 0), x).item()); + std::initializer_list expected_1 = { + {1.0, 0.0}, + {-1.0, 0.0}, + {5.0, 0.0}, + {-1.0, 0.0}, + }; + y = fft::fft(x, 1); + CHECK(array_equal(y, array(expected_1, {2, 2})).item()); + CHECK(array_equal(fft::ifft(y, 1), x).item()); + } + set_default_device(device); +} + +TEST_CASE("test real ffts") { + auto device = default_device(); + set_default_device(Device::cpu); + + auto x = array({1.0}); + auto y = fft::rfft(x); + CHECK_EQ(y.dtype(), complex64); + CHECK_EQ(y.size(), x.size()); + CHECK_EQ(y.item(), complex64_t{1.0f, 0.0f}); + + { + x = array({0.0f, 1.0f, 2.0f, 3.0f}); + y = fft::rfft(x); + std::initializer_list expected = { + {6.0, 0.0}, {-2.0, 2.0}, {-2.0, -0.0}}; + CHECK_EQ(y.size(), x.size() / 2 + 1); + CHECK(array_equal(y, array(expected)).item()); + } + + x = array(complex64_t{1, 1}); + CHECK_THROWS(fft::irfft(x)); + + x = array({complex64_t{0, 1}, complex64_t{1, 0}}); + y = fft::irfft(x); + CHECK_EQ(y.size(), 2); + CHECK_EQ(y.dtype(), float32); + CHECK(array_equal(y, array({0.5f, -0.5f})).item()); + + set_default_device(device); +} + +TEST_CASE("test fftn") { + auto device = default_device(); + set_default_device(Device::cpu); + + auto x = zeros({5, 5, 5}); + CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftn(x, {}, {0, 0}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftn(x, {5, 5, 5}, {0}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftn(x, {0}, {}, {}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftn(x, {1, -1}, {}, {}), std::invalid_argument); + + // Test 2D FFT + { + x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}); + 
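Annotation: the creation tests above cover ``full``'s dtype inference and broadcasting of array fill values, plus ``zeros``/``ones`` and their ``_like`` variants. Assuming the Python ops mirror the C++ ones exercised here, a brief sketch:

.. code-block:: python

    import mlx.core as mx

    print(mx.zeros((2, 2)).dtype)               # float32 by default
    print(mx.full((2, 2), 7).dtype)             # int32, inferred from the fill value
    print(mx.full((2, 2), mx.array([3, 4])))    # the fill value broadcasts to the shape
    print(mx.ones_like(mx.zeros((2, 2), mx.int32)).dtype)   # dtype is preserved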
std::initializer_list expected = { + {6.0, 0.0}, + {-2.0, 0.0}, + {-4.0, 0.0}, + {0.0, 0.0}, + }; + auto y = fft::fft2(x); + CHECK(array_equal(y, array(expected, {2, 2})).item()); + CHECK(array_equal(fft::ifft2(y), x).item()); + } + + // Test 3D FFT + { + x = reshape(arange(8, float32), {2, 2, 2}); + std::initializer_list expected = { + {28.0, 0.0}, + {-4.0, 0.0}, + {-8.0, 0.0}, + {0.0, 0.0}, + {-16.0, 0.0}, + {0.0, 0.0}, + {0.0, 0.0}, + {0.0, 0.0}, + }; + auto y = fft::fftn(x); + CHECK(array_equal(y, array(expected, {2, 2, 2})).item()); + CHECK(array_equal(fft::ifftn(y), x).item()); + + x = reshape(arange(20, float32), {5, 4}); + y = fft::rfftn(x); + CHECK_EQ(y.shape(), std::vector{5, 3}); + y = fft::rfftn(x, {1, 0}); + CHECK_EQ(y.shape(), std::vector{3, 4}); + + x = reshape(arange(20, float32), {5, 4}); + y = fft::irfftn(x); + CHECK_EQ(y.shape(), std::vector{5, 6}); + y = fft::irfftn(x, {1, 0}); + CHECK_EQ(y.shape(), std::vector{8, 4}); + } + + // Check the types of real ffts + { + x = zeros({5, 5}, float32); + auto y = fft::rfft2(x); + CHECK_EQ(y.shape(), std::vector{5, 3}); + CHECK_EQ(y.dtype(), complex64); + + y = fft::rfftn(x); + CHECK_EQ(y.shape(), std::vector{5, 3}); + CHECK_EQ(y.dtype(), complex64); + + x = zeros({5, 5}, complex64); + y = fft::irfft2(x); + CHECK_EQ(y.shape(), std::vector{5, 8}); + CHECK_EQ(y.dtype(), float32); + + y = fft::irfftn(x); + CHECK_EQ(y.shape(), std::vector{5, 8}); + CHECK_EQ(y.dtype(), float32); + } + + set_default_device(device); +} + +TEST_CASE("test fft with provided shape") { + auto x = ones({5, 5}); + + auto y = fft::fft(x, 7, 0); + CHECK_EQ(y.shape(), std::vector{7, 5}); + + y = fft::fft(x, 3, 0); + CHECK_EQ(y.shape(), std::vector{3, 5}); + + y = fft::fft(x, 7, 1); + CHECK_EQ(y.shape(), std::vector{5, 7}); + + y = fft::fft(x, 3, 1); + CHECK_EQ(y.shape(), std::vector{5, 3}); + + y = fft::rfft(x, 7, 0); + CHECK_EQ(y.shape(), std::vector{4, 5}); + + y = fft::rfft(x, 3, 0); + CHECK_EQ(y.shape(), std::vector{2, 5}); + + y = fft::rfft(x, 3, 1); + CHECK_EQ(y.shape(), std::vector{5, 2}); +} + +TEST_CASE("test fft vmap") { + auto device = default_device(); + set_default_device(Device::cpu); + + auto fft_fn = [](array x) { return fft::fft(x); }; + auto x = reshape(arange(8), {2, 4}); + auto y = vmap(fft_fn)(x); + CHECK(array_equal(y, fft::fft(x)).item()); + + y = vmap(fft_fn, 1, 1)(x); + CHECK(array_equal(y, fft::fft(x, 0)).item()); + + auto rfft_fn = [](array x) { return fft::rfft(x); }; + + y = vmap(rfft_fn)(x); + CHECK(array_equal(y, fft::rfft(x)).item()); + + y = vmap(rfft_fn, 1, 1)(x); + CHECK(array_equal(y, fft::rfft(x, 0)).item()); + + set_default_device(device); +} + +TEST_CASE("test fft grads") { + auto device = default_device(); + set_default_device(Device::cpu); + + // Regular + auto fft_fn = [](array x) { return fft::fft(x); }; + auto cotangent = astype(arange(10), complex64); + auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second; + CHECK(array_equal(fft::fft(cotangent), vjp_out).item()); + + auto tangent = astype(arange(10), complex64); + auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second; + CHECK(array_equal(fft::fft(tangent), jvp_out).item()); + + // Inverse + auto ifft_fn = [](array x) { return fft::ifft(x); }; + vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second; + CHECK(array_equal(fft::ifft(cotangent), vjp_out).item()); + + jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second; + CHECK(array_equal(fft::ifft(tangent), jvp_out).item()); + + // Real + auto rfft_fn = [](array x) { return 
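Annotation: the FFT tests check shapes and dtypes for ``fft``/``ifft``, the real variants, and the n-dimensional transforms, including padding or truncating via an explicit size. Assuming the ``mx.fft`` submodule mirrors the C++ ``fft::`` namespace, a short sketch:

.. code-block:: python

    import mlx.core as mx

    x = mx.array([0.0, 1.0, 2.0, 3.0])

    y = mx.fft.fft(x)
    print(y.dtype, y.shape)         # complex64 [4]
    print(mx.fft.ifft(y))           # recovers x (as complex64)

    # Real transforms keep only the non-redundant half: n // 2 + 1 bins
    print(mx.fft.rfft(x).shape)     # [3]

    # An explicit size pads or truncates along the transformed axis
    print(mx.fft.fft(x, 7).shape)   # [7]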
fft::rfft(x); }; + cotangent = astype(arange(6), complex64); + vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second; + auto expected = astype(fft::fft(cotangent, 10, 0), float32); + CHECK(array_equal(expected, vjp_out).item()); + + tangent = astype(arange(10), float32); + jvp_out = jvp(rfft_fn, zeros_like(tangent), tangent).second; + CHECK(array_equal(fft::rfft(tangent), jvp_out).item()); + + // Inverse real + auto irfft_fn = [](array x) { return fft::irfft(x); }; + cotangent = astype(arange(10), float32); + vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second; + expected = fft::fft(cotangent, 10, 0); + auto o_splits = split(vjp_out, {1, 5}); + auto e_splits = split(expected, {1, 5, 6}); + CHECK_EQ(e_splits[0].item(), o_splits[0].item()); + CHECK(array_equal(2 * e_splits[1], o_splits[1]).item()); + CHECK_EQ(e_splits[2].item(), o_splits[2].item()); + + tangent = astype(arange(10), complex64); + jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second; + CHECK(array_equal(fft::irfft(tangent), jvp_out).item()); + + // Check ND vjps run properly + vjp_out = vjp([](array x) { return fft::fftn(x); }, + astype(zeros({5, 5}), complex64), + astype(zeros({5, 5}), complex64)) + .second; + CHECK_EQ(vjp_out.shape(), std::vector{5, 5}); + + vjp_out = vjp([](array x) { return fft::ifftn(x); }, + astype(zeros({5, 5}), complex64), + astype(zeros({5, 5}), complex64)) + .second; + CHECK_EQ(vjp_out.shape(), std::vector{5, 5}); + + vjp_out = vjp([](array x) { return fft::rfftn(x); }, + zeros({5, 9}), + astype(zeros({5, 5}), complex64)) + .second; + CHECK_EQ(vjp_out.shape(), std::vector{5, 9}); + + vjp_out = vjp([](array x) { return fft::irfftn(x); }, + astype(zeros({5, 5}), complex64), + zeros({5, 8})) + .second; + CHECK_EQ(vjp_out.shape(), std::vector{5, 5}); + + set_default_device(device); +} diff --git a/tests/graph_optimize_tests.cpp b/tests/graph_optimize_tests.cpp new file mode 100644 index 000000000..669bb08ec --- /dev/null +++ b/tests/graph_optimize_tests.cpp @@ -0,0 +1,30 @@ +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test simplify scalars") { + auto a = array({-1.0f, 2.0f}); + auto b = maximum(a, array(0.0f)); + auto c = maximum(-a, array(0.0f)); + auto d = b + c; + simplify({d}); + CHECK(b.inputs()[1].id() == c.inputs()[1].id()); +} + +TEST_CASE("test simplify") { + auto a = array({1.0f, 2.0f}); + auto b = exp(a) + exp(a); + simplify(b); + eval(b); + CHECK(b.inputs()[0].id() == b.inputs()[1].id()); +} + +TEST_CASE("test no simplify") { + auto a = array({1.0f, 2.0f}); + auto b = cos(a) + sin(a); + simplify(b); + eval(b); + CHECK(b.inputs()[0].id() != b.inputs()[1].id()); +}
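Annotation: the gradient tests above drive the FFT primitives through ``vjp`` and ``jvp`` directly, and the simplify tests check that identical subgraphs are merged. Assuming the Python transforms mirror the C++ ones (lists of primals and cotangents/tangents in, lists of outputs out), a hedged sketch of the same vjp/jvp pattern:

.. code-block:: python

    import mlx.core as mx

    def f(x):
        return mx.fft.fft(x)

    x = mx.array([0.0, 1.0, 2.0, 3.0]).astype(mx.complex64)
    cotan = mx.array([1.0, 0.0, 0.0, 0.0]).astype(mx.complex64)

    outputs, vjps = mx.vjp(f, [x], [cotan])
    print(vjps[0].shape)        # same shape as x

    outputs, jvps = mx.jvp(f, [x], [cotan])
    print(jvps[0].shape)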