From ea406d5e3346fa2d08410fc657016d9ee9d6f5fc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 7 Feb 2024 06:04:34 -0800 Subject: [PATCH 01/42] CI change (#645) * CI update * Skip large binary test for now * Upgrade pip * Add proper env variable skipping * Update the CI * Fix workflow name * Set the low memory flag for the tests * Change build process * Add pip upgrade * Use a venv * Add a missing env activate * Add setuptools * Add twine upload back * Re-enable automatic release builds --- .circleci/config.yml | 206 +++++++++++---------------------------- MANIFEST.in | 2 +- python/tests/test_ops.py | 5 + 3 files changed, 62 insertions(+), 151 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 25fb71fb5..26aa21ce8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -57,17 +57,18 @@ jobs: command: ./build/tests/tests mac_build_and_test: - machine: true - resource_class: ml-explore/m-builder + macos: + xcode: "15.2.0" + resource_class: macos.m1.large.gen1 steps: - checkout - run: name: Install dependencies command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=3.9 - conda activate runner-env + brew install python@3.9 + python3.9 -m venv env + source env/bin/activate + pip install --upgrade pip pip install --upgrade cmake pip install --upgrade pybind11[global] pip install pybind11-stubgen @@ -78,191 +79,94 @@ jobs: - run: name: Install Python package command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace - CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop + source env/bin/activate + CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v - run: name: Generate package stubs command: | - eval "$(conda shell.bash hook)" - conda activate runner-env + source env/bin/activate python setup.py generate_stubs - run: name: Run Python tests command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu - DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu + source env/bin/activate + LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu + # TODO: Reenable when Circle CI can run gpu jobs + # DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu # TODO: Reenable when extension api becomes stable # - run: # name: Build example extension # command: | - # eval "$(conda shell.bash hook)" - # conda activate runner-env - # cd examples/extensions && python -m pip install . + # cd examples/extensions && python3.11 -m pip install . - store_test_results: path: test-results - run: name: Build CPP only command: | + source env/bin/activate mkdir -p build && cd build && cmake .. && make -j - run: name: Run CPP tests - command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests + #command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests + command: DEVICE=cpu ./build/tests/tests build_release: - machine: true - resource_class: ml-explore/m-builder parameters: python_version: type: string default: "3.9" - macos_version: + xcode_version: type: string - default: "14" + default: "15.2.0" + build_env: + type: string + default: "" + macos: + xcode: << parameters.xcode_version >> + resource_class: macos.m1.large.gen1 steps: - checkout - run: name: Install dependencies command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=<< parameters.python_version >> - conda activate runner-env + brew install python@<< parameters.python_version >> + python<< parameters.python_version >> -m venv env + source env/bin/activate + pip install --upgrade pip pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install --upgrade setuptools pip install pybind11-stubgen pip install numpy pip install twine - # TODO: Update build system to switch away from setup.py develop + pip install build - run: name: Install Python package command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - PYPI_RELEASE=1 \ + source env/bin/activate + DEV_RELEASE=1 \ CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py develop + pip install . -v - run: name: Generate package stubs command: | - eval "$(conda shell.bash hook)" - conda activate runner-env + source env/bin/activate python setup.py generate_stubs - run: - name: Publish Python package + name: Build Python package command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - PYPI_RELEASE=1 \ + source env/bin/activate + << parameters.build_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py bdist_wheel - twine upload dist/* --repository mlx - - store_artifacts: - path: dist/ - - build_dev_release: - machine: true - resource_class: ml-explore/m-builder - parameters: - python_version: - type: string - default: "3.9" - macos_version: - type: string - default: "14" - steps: - - checkout - - run: - name: Install dependencies - command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=<< parameters.python_version >> - conda activate runner-env - pip install --upgrade cmake - pip install --upgrade pybind11[global] - pip install pybind11-stubgen - pip install numpy - pip install twine - - run: - name: Install Python package - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - DEV_RELEASE=1 \ - CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py develop - - run: - name: Generate package stubs - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - python setup.py generate_stubs - - run: - name: Publish Python package - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - DEV_RELEASE=1 \ - CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py bdist_wheel - twine upload dist/* --repository mlx - - store_artifacts: - path: dist/ - - build_package: - machine: true - resource_class: ml-explore/m-builder - parameters: - python_version: - type: string - default: "3.9" - macos_version: - type: string - default: "14" - steps: - - checkout - - run: - name: Install dependencies - command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=<< parameters.python_version >> - conda activate runner-env - pip install --upgrade cmake - pip install --upgrade pybind11[global] - pip install pybind11-stubgen - pip install numpy - pip install twine - - run: - name: Install Python package - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py develop - - run: - name: Generate package stubs - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - python setup.py generate_stubs - - run: - name: Build package distribution - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py bdist_wheel + python -m build -w + - when: + condition: << parameters.build_env >> + steps: + - run: + name: Upload package + command: | + source env/bin/activate + twine upload dist/* - store_artifacts: path: dist/ @@ -273,8 +177,8 @@ workflows: - not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.weekly_build >> jobs: - - linux_build_and_test - mac_build_and_test + - linux_build_and_test - build_release: filters: tags: @@ -284,20 +188,22 @@ workflows: matrix: parameters: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - macos_version: ["13", "14"] + xcode_version: ["14.3.1", "15.2.0"] + build_env: ["PYPI_RELEASE=1"] nightly_build: when: << pipeline.parameters.nightly_build >> jobs: - - build_package: + - build_release: matrix: parameters: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - macos_version: ["13", "14"] + xcode_version: ["14.3.1", "15.2.0"] weekly_build: when: << pipeline.parameters.weekly_build >> jobs: - - build_dev_release: + - build_release: matrix: parameters: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - macos_version: ["13", "14"] + xcode_version: ["14.3.1", "15.2.0"] + build_env: ["DEV_RELEASE=1"] diff --git a/MANIFEST.in b/MANIFEST.in index d81234106..9faafee45 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include CMakeLists.txt recursive-include mlx/ * include python/src/* -python/mlx/py.typed # support type hinting as in PEP-561 +include python/mlx/py.typed # support type hinting as in PEP-561 diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 6ac46779d..3d84f4b02 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import math +import os import unittest from itertools import permutations @@ -1566,6 +1567,10 @@ class TestOps(mlx_tests.MLXTestCase): d_np = np.take(b_mx, np.arange(kth), axis=axis) self.assertTrue(np.all(d_np <= c_mx)) + @unittest.skipIf( + os.getenv("LOW_MEMORY", None) is not None, + "This test requires a lot of memory", + ) def test_large_binary(self): a = mx.ones([1000, 2147484], mx.int8) b = mx.ones([2147484], mx.int8) From ef73393a19bfc0f005f585656cbd3fa0ab865ec4 Mon Sep 17 00:00:00 2001 From: Aryan Gupta <97878444+guptaaryan16@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:09:52 +0530 Subject: [PATCH 02/42] Feat: Add weights argument in BCE Loss and tests (#620) --- python/mlx/nn/losses.py | 11 +++++++++++ python/tests/test_losses.py | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index a466c10ed..ee33fde3e 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -117,6 +117,7 @@ def cross_entropy( def binary_cross_entropy( inputs: mx.array, targets: mx.array, + weights: mx.array = None, with_logits: bool = True, reduction: Reduction = "mean", ) -> mx.array: @@ -128,6 +129,7 @@ def binary_cross_entropy( ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities. targets (array): The binary target values in {0, 1}. with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``. + weights (array, optional): Optional weights for each target. Default: ``None``. reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``. @@ -159,6 +161,15 @@ def binary_cross_entropy( else: loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs)) + # Apply weights if provided + if weights is not None: + if weights.shape != loss.shape: + raise ValueError( + f"Weights with shape {weights.shape} is not the same as " + f"output loss with shape {loss.shape}." + ) + loss *= weights + return _reduce(loss, reduction) diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index 2160b0a6e..3a430be21 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase): expected_sum = mx.sum(expected_none) self.assertEqual(losses_sum, expected_sum) + # With weights, no label smoothing + weights = mx.array([1.0, 2.0, 1.0, 2.0]) + expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944]) + loss = nn.losses.binary_cross_entropy( + logits, targets, weights=weights, reduction="none" + ) + self.assertTrue(mx.allclose(loss, expected)) + def _test_probs_as_inputs(): probs = mx.array([0.5, 0.6, 0.7, 0.8]) targets = mx.array([0, 0, 1, 1]) From 5fd11c347d4e30631592895ebb32e30ecaecb1e6 Mon Sep 17 00:00:00 2001 From: Noah Farr <69793313+noahfarr@users.noreply.github.com> Date: Wed, 7 Feb 2024 20:49:59 +0100 Subject: [PATCH 03/42] Add loc and scale to random.normal (#638) * Add loc and scale to random.normal * Add tests for loc and scale for random.normal * Run pre-commit hooks * Fix code review --- mlx/random.cpp | 13 +++++++++++-- mlx/random.h | 19 ++++++++++++++++++- python/src/random.cpp | 8 +++++++- python/tests/test_random.py | 14 ++++++++++++++ 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 63e39cdcc..5e0682d32 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -153,14 +153,23 @@ array uniform( array normal( const std::vector& shape, Dtype dtype, + const float loc /* = 0.0 */, + const float scale /* = 1.0 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); auto low = array(std::nextafter(-1.0f, 0.0f), dtype); auto high = array(1.0f, dtype); auto samples = uniform(low, high, shape, dtype, key, stream); - return multiply( - array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); + samples = + multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); + if (scale != 1.0) { + samples = multiply(array(scale, dtype), samples, stream); + } + if (loc != 0.0) { + samples = add(array(loc, dtype), samples, stream); + } + return samples; } array randint( diff --git a/mlx/random.h b/mlx/random.h index ab75eb488..1397b32d7 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -95,13 +95,30 @@ inline array uniform( array normal( const std::vector& shape, Dtype dtype, + const float loc, + const float scale, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array normal( const std::vector& shape, + const float loc, + const float scale, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, float32, key, s); + return normal(shape, float32, loc, scale, key, s); +} +inline array normal( + const std::vector& shape, + const Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, dtype, 0.0, 1.0, key, s); +} +inline array normal( + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, 0.0, 1.0, key, s); } /** Generate integer samples uniformly at random */ diff --git a/python/src/random.cpp b/python/src/random.cpp index 6e9f38d97..e9140e7d9 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -99,13 +99,17 @@ void init_random(py::module_& parent_module) { "normal", [](const std::vector& shape, std::optional type, + float loc, + float scale, const std::optional& key, StreamOrDevice s) { - return normal(shape, type.value_or(float32), key, s); + return normal(shape, type.value_or(float32), loc, scale, key, s); }, "shape"_a = std::vector{}, "dtype"_a = std::optional{float32}, + "loc"_a = 0.0, + "scale"_a = 1.0, "key"_a = none, "stream"_a = none, R"pbdoc( @@ -114,6 +118,8 @@ void init_random(py::module_& parent_module) { Args: shape (list(int), optional): Shape of the output. Default is ``()``. dtype (Dtype, optional): Type of the output. Default is ``float32``. + loc (float, optional): Mean of the distribution. Default is ``0.0``. + scale (float, optional): Standard deviation of the distribution. Default is ``1.0``. key (array, optional): A PRNG key. Default: None. Returns: diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 0a06c3496..892db37df 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -80,6 +80,20 @@ class TestRandom(mlx_tests.MLXTestCase): a = mx.random.normal(dtype=t) self.assertEqual(a.dtype, t) + # Generate with a given mean and standard deviation + loc = 1.0 + scale = 2.0 + + a = mx.random.normal(shape=(3, 2), loc=loc, scale=scale, key=key) + b = scale * mx.random.normal(shape=(3, 2), key=key) + loc + self.assertTrue(mx.allclose(a, b)) + + a = mx.random.normal( + shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key + ) + b = scale * mx.random.normal(shape=(3, 2), dtype=mx.float16, key=key) + loc + self.assertTrue(mx.allclose(a, b)) + self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype) def test_randint(self): From 28eac185710cf3a6ccd871d764c306849eaf902a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 7 Feb 2024 13:15:59 -0800 Subject: [PATCH 04/42] Kernel generation (#614) Generate reusable element-wise kernels given a computation graph. --- mlx/array.h | 15 + mlx/backend/metal/CMakeLists.txt | 21 + mlx/backend/metal/compiled.cpp | 494 +++++++++++++++++- mlx/backend/metal/compiled_preamble.h | 9 + mlx/backend/metal/device.cpp | 5 + mlx/backend/metal/device.h | 2 + mlx/backend/metal/kernels/binary.h | 221 ++++++++ mlx/backend/metal/kernels/binary.metal | 174 +----- mlx/backend/metal/kernels/compiled_preamble.h | 4 + mlx/backend/metal/kernels/unary.h | 376 +++++++++++++ mlx/backend/metal/kernels/unary.metal | 221 +------- mlx/backend/metal/kernels/utils.h | 10 +- mlx/backend/metal/make_compiled_preamble.sh | 28 + mlx/backend/metal/utils.h | 32 +- mlx/compile.cpp | 2 +- mlx/graph_utils.cpp | 35 +- mlx/graph_utils.h | 6 + mlx/primitives.h | 46 +- tests/compile_tests.cpp | 60 +++ 19 files changed, 1302 insertions(+), 459 deletions(-) create mode 100644 mlx/backend/metal/compiled_preamble.h create mode 100644 mlx/backend/metal/kernels/binary.h create mode 100644 mlx/backend/metal/kernels/compiled_preamble.h create mode 100644 mlx/backend/metal/kernels/unary.h create mode 100644 mlx/backend/metal/make_compiled_preamble.sh diff --git a/mlx/array.h b/mlx/array.h index 2b849a7ae..5eefcf727 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -121,6 +121,9 @@ class array { template T item(); + template + T item() const; + struct ArrayIterator { using iterator_category = std::random_access_iterator_tag; using difference_type = size_t; @@ -454,6 +457,18 @@ T array::item() { return *data(); } +template +T array::item() const { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + if (!is_evaled()) { + throw std::invalid_argument( + "item() const can only be called on evaled arrays"); + } + return *data(); +} + template void array::init(It src) { set_data(allocator::malloc(size() * size_of(dtype()))); diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index fd1a47f01..93a25434f 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -1,3 +1,23 @@ +add_custom_command( + OUTPUT compiled_preamble.cpp + COMMAND /bin/bash + ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_C_COMPILER} + ${CMAKE_SOURCE_DIR} + DEPENDS make_compiled_preamble.sh + kernels/compiled_preamble.h + kernels/unary.h + kernels/binary.h +) + +add_custom_target( + compiled_preamble + DEPENDS compiled_preamble.cpp +) + +add_dependencies(mlx compiled_preamble) + target_sources( mlx PRIVATE @@ -16,6 +36,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) if (NOT MLX_METAL_PATH) diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 4ce9b0c85..1f27a2493 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -1,44 +1,484 @@ // Copyright © 2023-2024 Apple Inc. +#include + +#include "mlx/backend/metal/compiled_preamble.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core { +inline bool is_static_cast(const Primitive& p) { + return ( + typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || + typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); +} + +inline auto get_type_string(Dtype d) { + switch (d) { + case float32: + return "float"; + case float16: + return "half"; + case bfloat16: + return "bfloat16_t"; + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + default: { + std::ostringstream msg; + msg << "Unsupported compilation type " << d; + throw std::runtime_error(msg.str()); + } + } +} + +template +void print_float_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + os << std::setprecision(std::numeric_limits::digits10 + 1) + << x.item() << std::setprecision(old_precision); +} + +template +void print_int_constant(std::ostream& os, const array& x) { + os << x.item(); +} + +void print_constant(std::ostream& os, const array& x) { + switch (x.dtype()) { + case float32: + return print_float_constant(os, x); + case float16: + return print_float_constant(os, x); + case bfloat16: + return print_float_constant(os, x); + case int8: + return print_int_constant(os, x); + case int16: + return print_int_constant(os, x); + case int32: + return print_int_constant(os, x); + case int64: + return print_int_constant(os, x); + case uint8: + return print_int_constant(os, x); + case uint16: + return print_int_constant(os, x); + case uint32: + return print_int_constant(os, x); + case uint64: + return print_int_constant(os, x); + case bool_: + os << std::boolalpha << x.item(); + return; + default: + throw std::runtime_error("Unsupported constant type"); + } +} + +inline std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids) { + std::ostringstream os; + std::ostringstream constant_hasher; + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. + for (auto& a : tape) { + a.primitive().print(os); + } + os << "_"; + + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << ((x.size() == 1) ? "S" : "V"); + } + } + os << "_"; + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + return os.str(); +} + +inline void build_kernel( + std::ostream& os, + const std::string& kernel_name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids, + bool contiguous, + int ndim, + bool dynamic_dims) { + // All outputs should have the exact same shape and will be row contiguous + auto output_shape = outputs[0].shape(); + auto output_strides = outputs[0].strides(); + + // Constants are scalars that are captured by value and cannot change + auto is_constant = [&constant_ids](const array& x) { + return constant_ids.find(x.id()) != constant_ids.end(); + }; + + // For scalar we shouldn't do the indexing things, just read at 0 + auto is_scalar = [](const array& x) { return x.size() == 1; }; + + NodeNamer namer; + bool add_indices = false; + int cnt = 0; + + // Start the kernel + os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl + << "[[kernel]] void " << kernel_name << "(" << std::endl; + + // Add the input arguments + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + // Skip constants from the input list + if (is_constant(x)) { + continue; + } + + // Scalars and contiguous need no strides + if (is_scalar(x) || contiguous) { + os << " device const " << get_type_string(x.dtype()) << "* " << xname + << " [[buffer(" << cnt++ << ")]]," << std::endl; + } else { + add_indices = true; + os << " device const " << get_type_string(x.dtype()) << "* " << xname + << " [[buffer(" << cnt++ << ")]]," << std::endl + << " constant const size_t* " << xname << "_strides [[buffer(" + << cnt++ << ")]]," << std::endl; + } + } + + // Add the output arguments + for (auto& x : outputs) { + os << " device " << get_type_string(x.dtype()) << "* " + << namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl; + } + // Add output strides and shape to extract the indices. + if (!contiguous) { + os << " constant const size_t* output_strides [[buffer(" << cnt++ + << ")]]," << std::endl + << " constant const int* output_shape [[buffer(" << cnt++ << ")]]," + << std::endl; + } + if (dynamic_dims) { + os << " constant const int& ndim [[buffer(" << cnt++ << ")]]," + << std::endl; + } + + // The thread index in the whole grid + os << " uint3 pos [[thread_position_in_grid]]," << std::endl + << " uint3 grid [[threads_per_grid]]) {" << std::endl + << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);" + << std::endl; + + // Extract the indices per axis to individual uints if we have arrays that + // are broadcasted or transposed + if (add_indices) { + if (!dynamic_dims) { + if (ndim == 1) { + os << " uint index_0 = pos.x;" << std::endl; + } else if (ndim == 2) { + os << " uint index_0 = pos.y;" << std::endl + << " uint index_1 = pos.x;" << std::endl; + } else if (ndim == 3) { + os << " uint index_0 = pos.z;" << std::endl + << " uint index_1 = pos.y;" << std::endl + << " uint index_2 = pos.x;" << std::endl; + } else { + for (int i = 0; i < ndim - 2; i++) { + os << " uint index_" << i << " = (index / uint(output_strides[" << i + << "])) % output_shape[" << i << "];" << std::endl; + } + os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl + << " uint index_" << ndim - 1 << " = pos.x;" << std::endl; + } + } + } + + // Read the inputs in tmps + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + if (is_constant(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; + print_constant(os, x); + os << ";" << std::endl; + } else if (is_scalar(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[0];" << std::endl; + } else if (contiguous) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[index];" << std::endl; + } else if (!dynamic_dims) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "["; + os << "index_0 * " << xname << "_strides[0]"; + for (int i = 1; i < ndim; i++) { + os << " + index_" << i << " * " << xname << "_strides[" << i << "]"; + } + os << "];" << std::endl; + } else { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[elem_to_loc(index, output_shape, " << xname + << "_strides, ndim)];" << std::endl; + } + } + + // Actually write the computation + for (auto& x : tape) { + os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) + << " = "; + if (is_static_cast(x.primitive())) { + os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" + << namer.get_name(x.inputs()[0]) << ");" << std::endl; + } else { + x.primitive().print(os); + os << "()("; + for (int i = 0; i < x.inputs().size() - 1; i++) { + os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; + } + os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; + } + } + + // Write the outputs from tmps + for (auto& x : outputs) { + os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x) + << ";" << std::endl; + } + + // Finish the kernel + os << "}" << std::endl; + + if (cnt > 31) { + std::ostringstream msg; + msg << "[compile] Too many inputs/outputs fused in the Metal Compile " + << "primitive which exhausted the available argument buffers for " + << "the kernel. Please file an issue with the function that results " + << "in this error. The name of the kernel is '" << kernel_name << "'"; + throw std::runtime_error(msg.str()); + } +} + void Compiled::eval_gpu( const std::vector& inputs, std::vector& outputs) { - // Just a fall-back to the original tape for now - std::unordered_map trace_to_real; - for (int i = 0; i < inputs.size(); ++i) { - trace_to_real.insert({inputs_[i].id(), inputs[i]}); - } - for (int i = 0; i < outputs.size(); ++i) { - trace_to_real.insert({outputs_[i].id(), outputs[i]}); + // Make the name for the kernel library + if (kernel_lib_.empty()) { + kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); } - for (auto& a : tape_) { - std::vector p_inputs; - for (auto& in : a.inputs()) { - p_inputs.push_back(trace_to_real.at(in.id())); - } - // If a is an output get it from the map, otherwise create it - // NB this is safe as long as no multi-output sub primitves are allowed - // in Compiled - std::vector p_outputs; - if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) { - p_outputs.push_back(it->second); - } else { - p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {})); - trace_to_real.insert({a.id(), p_outputs[0]}); - } - a.primitive().eval_gpu(p_inputs, p_outputs); - } + // Get the kernel if someone else built it already auto& s = stream(); auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler( - [trace_to_real](MTL::CommandBuffer*) mutable {}); + auto lib = d.get_library(kernel_lib_); + + // If not we have to build it ourselves + if (lib == nullptr) { + std::ostringstream kernel; + kernel << metal::get_kernel_preamble() << std::endl; + build_kernel( + kernel, + kernel_lib_ + "_contiguous", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ true, + /* ndim = */ 0, + /* dynamic_dims = */ false); + for (int i = 1; i < 8; i++) { + build_kernel( + kernel, + kernel_lib_ + "_strided_" + std::to_string(i), + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ i, + /* dynamic_dims = */ false); + } + build_kernel( + kernel, + kernel_lib_ + "_strided_dynamic", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ 0, + /* dynamic_dims = */ true); + + kernel_source_ = kernel.str(); + lib = d.get_library(kernel_lib_, kernel_source_); + } + + // Allocate space for the outputs + for (auto& out : outputs) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + + // Figure out which kernel we are using + auto& output_shape = outputs[0].shape(); + bool contiguous = true; + for (auto& x : inputs) { + if ((!x.flags().row_contiguous || x.shape() != output_shape) && + x.size() > 1) { + contiguous = false; + break; + } + } + + // Collapse contiguous dims to route to a faster kernel if possible. Also + // handle all broadcasting. + std::vector> initial_strides; + initial_strides.push_back(outputs[0].strides()); + std::vector shape; + std::vector> strides; + if (!contiguous) { + for (int i = 0; i < inputs.size(); i++) { + // Skip constants. + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + + // Skip scalar inputs. + if (x.size() <= 1) { + continue; + } + + // Broadcast the inputs to the output shape. + std::vector xstrides; + int j = 0; + for (; j < output_shape.size() - x.ndim(); j++) { + if (output_shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (int i = 0; i < x.ndim(); i++, j++) { + if (x.shape(i) == 1) { + if (output_shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + initial_strides.push_back(std::move(xstrides)); + } + std::tie(shape, strides) = + collapse_contiguous_dims(output_shape, initial_strides); + } + + // Get the kernel from the lib + int ndim = shape.size(); + bool dynamic = ndim >= 8; + auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + if (!contiguous) { + if (dynamic) { + kernel_name += "dynamic"; + } else { + kernel_name += std::to_string(shape.size()); + } + } + auto kernel = d.get_kernel(kernel_name, lib); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + // Put the inputs in + int cnt = 0; + int stride_idx = 1; // idx 0 is the output strides + for (int i = 0; i < inputs.size(); i++) { + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + set_array_buffer(compute_encoder, x, cnt++); + if (!contiguous && x.size() > 1) { + compute_encoder->setBytes( + strides[stride_idx].data(), + strides[stride_idx].size() * sizeof(size_t), + cnt++); + stride_idx++; + } + } + + // Put the outputs in + for (auto& x : outputs) { + set_array_buffer(compute_encoder, x, cnt++); + } + + // Put the output shape and strides in + if (!contiguous) { + compute_encoder->setBytes( + strides[0].data(), strides[0].size() * sizeof(size_t), cnt++); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++); + } + + // Put the number of dims in if it is dynamic + if (dynamic) { + compute_encoder->setBytes(&ndim, sizeof(int), cnt++); + } + + // Launch the kernel + if (contiguous) { + size_t nthreads = outputs[0].size(); + MTL::Size grid_dims(nthreads, 1, 1); + MTL::Size group_dims( + std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t rest = outputs[0].size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/compiled_preamble.h b/mlx/backend/metal/compiled_preamble.h new file mode 100644 index 000000000..9122d3d54 --- /dev/null +++ b/mlx/backend/metal/compiled_preamble.h @@ -0,0 +1,9 @@ +// Copyright © 2023-24 Apple Inc. + +#pragma once + +namespace mlx::core::metal { + +const char* get_kernel_preamble(); + +} diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7c61e68ae..e50441d48 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -414,6 +414,11 @@ MTL::ComputePipelineState* Device::get_kernel_( return kernel; } +MTL::Library* Device::get_library(const std::string& name) { + auto it = library_map_.find(name); + return (it != library_map_.end()) ? it->second : nullptr; +} + MTL::Library* Device::get_library( const std::string& name, const std::string& source, diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 6acfe9332..8312084ce 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -62,6 +62,8 @@ class Device { const std::function& lib_path_func = get_colocated_mtllib_path); + MTL::Library* get_library(const std::string& name); + MTL::Library* get_library( const std::string& name, const std::string& source_string, diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h new file mode 100644 index 000000000..8adb84c58 --- /dev/null +++ b/mlx/backend/metal/kernels/binary.h @@ -0,0 +1,221 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + T operator()(T x, T y) { + return x % y; + } + template <> + float operator()(float x, float y) { + return fmod(x, y); + } + template <> + half operator()(half x, half y) { + return fmod(x, y); + } + template <> + bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { + return fmod(x, y); + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && + metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + if (metal::isnan(x) || metal::isnan(y)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : (maxval + log1p(metal::exp(minval - maxval))); + }; +}; + +struct Maximum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::max(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x > y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x > y ? x : y; + } +}; + +struct Minimum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::min(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x < y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x < y ? x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + auto x_theta = metal::atan(x.imag / x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 1b84c70a5..4d449ab69 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -1,176 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/bf16.h" - -struct Add { - template T operator()(T x, T y) { return x + y; } -}; - -struct Divide { - template T operator()(T x, T y) { return x / y; } -}; - -struct Remainder { - template T operator()(T x, T y) { return x % y; } - template <> float operator()(float x, float y) { return fmod(x, y); } - template <> half operator()(half x, half y) { return fmod(x, y); } - template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); } -}; - -struct Equal { - template bool operator()(T x, T y) { return x == y; } -}; - -struct NaNEqual { - template bool operator()(T x, T y) { - return x == y || (metal::isnan(x) && metal::isnan(y)); - } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x == y || - (metal::isnan(x.real) && metal::isnan(y.real) - && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); - } -}; - -struct Greater { - template bool operator()(T x, T y) { return x > y; } -}; - -struct GreaterEqual { - template bool operator()(T x, T y) { return x >= y; } -}; - -struct Less { - template bool operator()(T x, T y) { return x < y; } -}; - -struct LessEqual { - template bool operator()(T x, T y) { return x <= y; } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - if (metal::isnan(x) || metal::isnan(y)) { - return metal::numeric_limits::quiet_NaN(); - } - constexpr T inf = metal::numeric_limits::infinity(); - T maxval = metal::max(x, y); - T minval = metal::min(x, y); - return (minval == -inf || maxval == inf) ? maxval : - (maxval + log1p(metal::exp(minval - maxval))); - }; -}; - -struct Maximum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::max(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x > y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x > y ? x : y; - } -}; - -struct Minimum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::min(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x < y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x < y ? x : y; - } -}; - -struct Multiply { - template T operator()(T x, T y) { return x * y; } -}; - -struct NotEqual { - template bool operator()(T x, T y) { return x != y; } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x.real != y.real || x.imag != y.imag; - } -}; - -struct Power { - - template - metal::enable_if_t, T> operator()(T base, T exp) { - return metal::pow(base, exp); - } - - template - metal::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - auto x_theta = metal::atan(x.imag / x.real); - auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); - auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); - auto phase = y.imag * x_ln_r + y.real * x_theta; - return {mag * metal::cos(phase), mag * metal::sin(phase)}; - } -}; - - -struct Subtract { - template T operator()(T x, T y) { return x - y; } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { return x && y; }; -}; - -struct LogicalOr { - template - T operator()(T x, T y) { return x || y; }; -}; +#include "mlx/backend/metal/kernels/binary.h" template [[kernel]] void binary_op_s2s( diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h new file mode 100644 index 000000000..82a9e9c5c --- /dev/null +++ b/mlx/backend/metal/kernels/compiled_preamble.h @@ -0,0 +1,4 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/kernels/binary.h" +#include "mlx/backend/metal/kernels/unary.h" diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h new file mode 100644 index 000000000..6d086b775 --- /dev/null +++ b/mlx/backend/metal/kernels/unary.h @@ -0,0 +1,376 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct Abs { + template + T operator()(T x) { + return metal::abs(x); + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return metal::precise::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return metal::precise::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return metal::precise::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return metal::precise::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return metal::precise::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return metal::precise::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return metal::ceil(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return metal::precise::cos(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Cosh { + template + T operator()(T x) { + return metal::precise::cosh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return metal::precise::exp(x); + }; + template <> + complex64_t operator()(complex64_t x) { + auto m = metal::precise::exp(x.real); + return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + } +}; + +struct Floor { + template + T operator()(T x) { + return metal::floor(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return metal::precise::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return metal::precise::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return metal::precise::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return metal::rint(x); + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::rint(x.real), metal::rint(x.imag)}; + }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + }; + template <> + uint32_t operator()(uint32_t x) { + return x != 0; + }; +}; + +struct Sin { + template + T operator()(T x) { + return metal::precise::sin(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Sinh { + template + T operator()(T x) { + return metal::precise::sinh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return metal::precise::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return metal::precise::rsqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return metal::precise::tan(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + }; +}; + +struct Tanh { + template + T operator()(T x) { + return metal::precise::tanh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + }; +}; diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 681d7707f..154db0520 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -1,223 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/erf.h" -#include "mlx/backend/metal/kernels/bf16.h" - -struct Abs { - template T operator()(T x) { return metal::abs(x); }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; - template <> complex64_t operator()(complex64_t x) { - return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; - }; -}; - -struct ArcCos { - template T operator()(T x) { return metal::precise::acos(x); }; -}; - -struct ArcCosh { - template T operator()(T x) { return metal::precise::acosh(x); }; -}; - -struct ArcSin { - template T operator()(T x) { return metal::precise::asin(x); }; -}; - -struct ArcSinh { - template T operator()(T x) { return metal::precise::asinh(x); }; -}; - -struct ArcTan { - template T operator()(T x) { return metal::precise::atan(x); }; -}; - -struct ArcTanh { - template T operator()(T x) { return metal::precise::atanh(x); }; -}; - -struct Ceil { - template T operator()(T x) { return metal::ceil(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; -}; - -struct Cos { - template T operator()(T x) { return metal::precise::cos(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cos(x.real) * metal::precise::cosh(x.imag), - -metal::precise::sin(x.real) * metal::precise::sinh(x.imag) - }; - }; -}; - -struct Cosh { - template T operator()(T x) { return metal::precise::cosh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cosh(x.real) * metal::precise::cos(x.imag), - metal::precise::sinh(x.real) * metal::precise::sin(x.imag) - }; - }; -}; - -struct Erf { - template T operator()(T x) { return static_cast(erf(static_cast(x))); }; -}; - -struct ErfInv { - template T operator()(T x) { return static_cast(erfinv(static_cast(x))); }; -}; - -struct Exp { - template T operator()(T x) { return metal::precise::exp(x); }; - template <> complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; - } -}; - -struct Floor { - template T operator()(T x) { return metal::floor(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; -}; - -struct Log { - template T operator()(T x) { return metal::precise::log(x); }; -}; - -struct Log2 { - template T operator()(T x) { return metal::precise::log2(x); }; -}; - -struct Log10 { - template T operator()(T x) { return metal::precise::log10(x); }; -}; - -struct Log1p { - template T operator()(T x) { return log1p(x); }; -}; - -struct LogicalNot { - template T operator()(T x) { return !x; }; -}; - -struct Negative { - template T operator()(T x) { return -x; }; -}; - -struct Round { - template T operator()(T x) { return metal::rint(x); }; - template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; }; -}; - -struct Sigmoid { - template - T operator()(T x) { - auto y = 1 / (1 + metal::exp(-metal::abs(x))); - return (x < 0) ? 1 - y : y; - } -}; - -struct Sign { - template T operator()(T x) { return (x > T(0)) - (x < T(0)); }; - template <> uint32_t operator()(uint32_t x) { return x != 0; }; -}; - -struct Sin { - template T operator()(T x) { return metal::precise::sin(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sin(x.real) * metal::precise::cosh(x.imag), - metal::precise::cos(x.real) * metal::precise::sinh(x.imag) - }; - }; -}; - -struct Sinh { - template T operator()(T x) { return metal::precise::sinh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sinh(x.real) * metal::precise::cos(x.imag), - metal::precise::cosh(x.real) * metal::precise::sin(x.imag) - }; - }; -}; - -struct Square { - template T operator()(T x) { return x * x; }; -}; - -struct Sqrt { - template T operator()(T x) { return metal::precise::sqrt(x); }; -}; - -struct Rsqrt { - template T operator()(T x) { return metal::precise::rsqrt(x); }; -}; - -struct Tan { - template T operator()(T x) { return metal::precise::tan(x); }; - - template <> - complex64_t operator()(complex64_t x) { - float tan_a = metal::precise::tan(x.real); - float tanh_b = metal::precise::tanh(x.imag); - float t1 = tan_a * tanh_b; - float denom = 1. + t1 * t1; - return { - (tan_a - tanh_b * t1) / denom, - (tanh_b + tan_a * t1) / denom - }; - }; -}; - -struct Tanh { - template T operator()(T x) { return metal::precise::tanh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - float tanh_a = metal::precise::tanh(x.real); - float tan_b = metal::precise::tan(x.imag); - float t1 = tanh_a * tan_b; - float denom = 1. + t1 * t1; - return { - (tanh_a + tan_b * t1) / denom, - (tan_b - tanh_a * t1) / denom - }; - }; -}; +#include "mlx/backend/metal/kernels/unary.h" template [[kernel]] void unary_op_v( diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 634a9d6df..f9d507cf2 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -12,10 +12,10 @@ template struct Limits { - static const constant U max; - static const constant U min; - static const constant U finite_max; - static const constant U finite_min; + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); }; #define instantiate_default_limit(type) \ @@ -273,4 +273,4 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { inline bool simd_shuffle_down(bool data, uint16_t delta) { return simd_shuffle_down(static_cast(data), delta); -} \ No newline at end of file +} diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh new file mode 100644 index 000000000..1271f567d --- /dev/null +++ b/mlx/backend/metal/make_compiled_preamble.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# +# This script generates a C++ function that provides the Metal unary and binary +# ops at runtime for use with kernel generation. +# +# Copyright © 2023-24 Apple Inc. + + +OUTPUT_FILE=$1 +CC=$2 +SRCDIR=$3 + +CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null) + +cat << EOF > $OUTPUT_FILE +// Copyright © 2023-24 Apple Inc. + +namespace mlx::core::metal { + +const char* get_kernel_preamble() { + return R"preamble( +$CONTENT +)preamble"; + +} + +} // namespace mlx::core::metal +EOF diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 378850802..f7c672c9f 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -117,16 +117,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) { // When multiple arrays are passed they should all have the same shape. The // collapsed axes are also the same so one shape is returned. std::tuple, std::vector>> -collapse_contiguous_dims(const std::vector& xs) { +collapse_contiguous_dims( + const std::vector& shape, + const std::vector> strides) { // Make a vector that has axes separated with -1. Collapse all axes between // -1. std::vector to_collapse; - if (xs[0].ndim() > 0) { + if (shape.size() > 0) { to_collapse.push_back(0); - for (int i = 1; i < xs[0].ndim(); i++) { + for (int i = 1; i < shape.size(); i++) { bool contiguous = true; - for (auto& x : xs) { - if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) { + for (const std::vector& st : strides) { + if (st[i] * shape[i] != st[i - 1]) { contiguous = false; } if (!contiguous) { @@ -142,21 +144,31 @@ collapse_contiguous_dims(const std::vector& xs) { } std::vector out_shape; - std::vector> out_strides(xs.size()); + std::vector> out_strides(strides.size()); for (int i = 0; i < to_collapse.size(); i++) { - int current_shape = xs[0].shape()[to_collapse[i]]; + int current_shape = shape[to_collapse[i]]; while (to_collapse[++i] != -1) { - current_shape *= xs[0].shape()[to_collapse[i]]; + current_shape *= shape[to_collapse[i]]; } out_shape.push_back(current_shape); - for (int j = 0; j < xs.size(); j++) { - out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]); + for (int j = 0; j < strides.size(); j++) { + const std::vector& st = strides[j]; + out_strides[j].push_back(st[to_collapse[i - 1]]); } } return std::make_tuple(out_shape, out_strides); } +std::tuple, std::vector>> +collapse_contiguous_dims(const std::vector& xs) { + std::vector> strides; + for (auto& x : xs) { + strides.emplace_back(x.strides()); + } + return collapse_contiguous_dims(xs[0].shape(), strides); +} + template std::tuple, std::vector>> collapse_contiguous_dims(Arrays... xs) { diff --git a/mlx/compile.cpp b/mlx/compile.cpp index fa9e0a987..c8ee3b0da 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -13,7 +13,7 @@ namespace mlx::core { -constexpr int max_compile_depth = 6; +constexpr int max_compile_depth = 10; bool is_unary(const Primitive& p) { return ( diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 7c1c17740..ba031c441 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -12,27 +12,24 @@ namespace mlx::core { -struct NodeNamer { - std::unordered_map names; - - std::string get_name(const array& x) { - auto it = names.find(x.id()); - if (it == names.end()) { - // Get the next name in the sequence - // [A, B, ..., Z, AA, AB, ...] - std::vector letters; - auto var_num = names.size() + 1; - while (var_num > 0) { - letters.push_back('A' + (var_num - 1) % 26); - var_num = (var_num - 1) / 26; - } - std::string name(letters.rbegin(), letters.rend()); - names.insert({x.id(), name}); - return name; +const std::string& NodeNamer::get_name(const array& x) { + auto it = names.find(x.id()); + if (it == names.end()) { + // Get the next name in the sequence + // [A, B, ..., Z, AA, AB, ...] + std::vector letters; + auto var_num = names.size() + 1; + while (var_num > 0) { + letters.push_back('A' + (var_num - 1) % 26); + var_num = (var_num - 1) / 26; } - return it->second; + std::string name(letters.rbegin(), letters.rend()); + names.insert({x.id(), name}); + + return get_name(x); } -}; + return it->second; +} void depth_first_traversal( std::function callback, diff --git a/mlx/graph_utils.h b/mlx/graph_utils.h index 3bd373bec..5e024704e 100644 --- a/mlx/graph_utils.h +++ b/mlx/graph_utils.h @@ -6,6 +6,12 @@ namespace mlx::core { +struct NodeNamer { + std::unordered_map names; + + const std::string& get_name(const array& x); +}; + void print_graph(std::ostream& os, const std::vector& outputs); template diff --git a/mlx/primitives.h b/mlx/primitives.h index 5bdee12cf..b06a35780 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -473,23 +473,30 @@ class Compiled : public Primitive { void eval_cpu(const std::vector& inputs, std::vector& outputs) override; - void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_VMAP() DEFINE_GRADS() - void print(std::ostream& os) override; - bool is_equivalent(const Primitive& other) const override; + std::string metal_lib_name() const { + return kernel_lib_; + } + std::string metal_lib_source() const { + return kernel_source_; + } + private: const std::vector inputs_; const std::vector outputs_; const std::vector tape_; const std::unordered_set constant_ids_; + std::string kernel_lib_; + std::string kernel_source_; + void eval(const std::vector& inputs, std::vector& out); }; @@ -709,9 +716,16 @@ class Equal : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Equal) DEFINE_DEFAULT_IS_EQUIVALENT() + void print(std::ostream& os) override { + if (equal_nan_) { + os << "NanEqual"; + } else { + os << "Equal"; + } + } + private: void eval(const std::vector& inputs, array& out); bool equal_nan_; @@ -945,9 +959,22 @@ class Log : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Log) DEFINE_DEFAULT_IS_EQUIVALENT() + void print(std::ostream& os) override { + switch (base_) { + case e: + os << "Log"; + break; + case two: + os << "Log2"; + break; + case ten: + os << "Log10"; + break; + } + } + private: Base base_; void eval(const std::vector& inputs, array& out); @@ -1594,9 +1621,16 @@ class Sqrt : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sqrt) bool is_equivalent(const Primitive& other) const override; + void print(std::ostream& os) override { + if (recip_) { + os << "Rsqrt"; + } else { + os << "Sqrt"; + } + } + private: void eval(const std::vector& inputs, array& out); bool recip_; diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 935a881a8..be460e3b6 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -623,3 +623,63 @@ TEST_CASE("test transform compiled function") { CHECK(!outs[0].inputs()[0].has_primitive()); CHECK(!outs[0].inputs()[1].has_primitive()); } + +TEST_CASE("test metal fusion kernel reuse") { + if (default_device() != Device::gpu) { + return; + } + + auto cfun = compile(gelu_1); + auto x = array({2.0f, -2.0f}); + auto y = cfun({x})[0]; + auto p = std::dynamic_pointer_cast(y.primitive_ptr()); + eval(y); + + std::string lib_name = p->metal_lib_name(); + std::string lib_source = p->metal_lib_source(); + CHECK(!lib_name.empty()); + CHECK(!lib_source.empty()); + + x = astype(reshape(arange(10), {2, 5}), float32); + auto z = cfun({x})[0]; + auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); + eval(z); + + std::string lib_name_z = pz->metal_lib_name(); + std::string lib_source_z = pz->metal_lib_source(); + CHECK(!lib_name_z.empty()); + CHECK(lib_source_z.empty()); + + CHECK_EQ(lib_name, lib_name_z); +} + +auto add3(const std::vector& xs) { + return std::vector{xs[0] + xs[0] + xs[0]}; +} + +TEST_CASE("test metal fusion types") { + if (default_device() != Device::gpu) { + return; + } + + auto cfun = compile(add3); + auto x = array({2.0f, -2.0f}); + auto y = cfun({x})[0]; + auto p = std::dynamic_pointer_cast(y.primitive_ptr()); + eval(y); + + std::string lib_name = p->metal_lib_name(); + std::string lib_source = p->metal_lib_source(); + CHECK(!lib_name.empty()); + CHECK(!lib_source.empty()); + + x = array({2, -2}, int32); + auto z = cfun({x})[0]; + auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); + eval(z); + + std::string lib_name_z = pz->metal_lib_name(); + std::string lib_source_z = pz->metal_lib_source(); + CHECK(!lib_name_z.empty()); + CHECK(!lib_source_z.empty()); +} From e5e816a5efa7f639469737a25ca947a40e8bf76a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 7 Feb 2024 13:22:27 -0800 Subject: [PATCH 05/42] fix sequential with empty modules at end (#647) --- python/mlx/nn/layers/base.py | 2 +- python/tests/test_nn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 3da1993ec..febbafa78 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -312,7 +312,7 @@ class Module(dict): elif isinstance(current_value, (dict, list)): apply(current_value, new_value) elif isinstance(parameters, list): - for i in range(len(dst)): + for i in range(len(parameters)): current_value = dst[i] new_value = parameters[i] if isinstance(current_value, mx.array): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 7749e159a..d7b84bbf6 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase): def test_save_safetensors_weights(self): def make_model(): - return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) + return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU()) m = make_model() tdir = tempfile.TemporaryDirectory() From 1b97b2958b267d78180bdff161975693c5007680 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 7 Feb 2024 17:29:22 -0800 Subject: [PATCH 06/42] Compile with capture (#629) * Simple kernel generation * Remove the generate kernel from graph_utils * fix multi-output with compile * fuse with stopgrad * v1 input, output capture in compile * cleanup tree update with visitor update * nit * remove todo * state for model, optional explicit init and more pure optimizer steps * move learning rate to state * add lr to opt state, some fixes in capture * fix optim * update tuple of containers as well * fix stream for compiled output * rng state for compile * nit * updates and comments --------- Co-authored-by: Angelos Katharopoulos --- docs/src/_templates/nn-module-template.rst | 19 -- docs/src/python/nn/module.rst | 1 + docs/src/python/optimizer.rst | 23 ++ docs/src/python/optimizers.rst | 6 +- mlx/compile.cpp | 7 +- python/mlx/nn/layers/base.py | 13 ++ python/mlx/optimizers.py | 258 ++++++++++++--------- python/src/random.cpp | 70 +++++- python/src/transforms.cpp | 144 ++++++++++-- python/tests/test_compile.py | 80 +++++++ python/tests/test_eval.py | 8 + python/tests/test_nn.py | 5 + python/tests/test_optimizers.py | 246 +++++++++++++++++++- 13 files changed, 723 insertions(+), 157 deletions(-) delete mode 100644 docs/src/_templates/nn-module-template.rst create mode 100644 docs/src/python/optimizer.rst diff --git a/docs/src/_templates/nn-module-template.rst b/docs/src/_templates/nn-module-template.rst deleted file mode 100644 index 49f018eb5..000000000 --- a/docs/src/_templates/nn-module-template.rst +++ /dev/null @@ -1,19 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - - {#{% block methods %} - - {% if methods %} - .. rubric:: {{ _('Methods') }} - - .. autosummary:: - {% for item in methods %} - {%- if item not in inherited_members and item != '__init__' %} - ~{{ name }}.{{ item }} - {%- endif %} - {%- endfor %} - {% endif %} - {% endblock %}#} diff --git a/docs/src/python/nn/module.rst b/docs/src/python/nn/module.rst index 042a88028..c3a4dfa62 100644 --- a/docs/src/python/nn/module.rst +++ b/docs/src/python/nn/module.rst @@ -11,6 +11,7 @@ Module :toctree: _autosummary Module.training + Module.state .. rubric:: Methods diff --git a/docs/src/python/optimizer.rst b/docs/src/python/optimizer.rst new file mode 100644 index 000000000..cf6034dee --- /dev/null +++ b/docs/src/python/optimizer.rst @@ -0,0 +1,23 @@ +Optimizer +========= + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Optimizer + + + .. rubric:: Attributes + + .. autosummary:: + :toctree: _autosummary + + Optimizer.state + + .. rubric:: Methods + + .. autosummary:: + :toctree: _autosummary + + Optimizer.apply_gradients + Optimizer.init + Optimizer.update diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst index fe8632a7e..4ef43d50f 100644 --- a/docs/src/python/optimizers.rst +++ b/docs/src/python/optimizers.rst @@ -29,14 +29,16 @@ model's parameters and the **optimizer state**. # Compute the new parameters but also the optimizer state. mx.eval(model.parameters(), optimizer.state) +.. toctree:: + + optimizer + .. currentmodule:: mlx.optimizers .. autosummary:: :toctree: _autosummary :template: optimizers-template.rst - OptimizerState - Optimizer SGD RMSprop Adagrad diff --git a/mlx/compile.cpp b/mlx/compile.cpp index c8ee3b0da..e69c442f2 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -191,10 +191,7 @@ struct CompilerCache { auto is_match = [](const std::vector& in1, const std::vector& in2) { if (in1.size() != in2.size()) { - std::ostringstream msg; - msg << "[compiler] Unexpected number of inputs to compiled function:" - << " expected " << in2.size() << " got " << in1.size() << "."; - throw std::invalid_argument(msg.str()); + return false; } for (int i = 0; i < in1.size(); ++i) { if (in1[i].shape() != in2[i].shape()) { @@ -603,7 +600,7 @@ void compile_fuse( shapes, types, std::make_shared( - outputs.back().primitive().stream(), + old_outputs.back().primitive().stream(), inputs, old_outputs, std::move(fused_tape), diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index febbafa78..de7097673 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -66,6 +66,19 @@ class Module(dict): """Boolean indicating if the model is in training mode.""" return self._training + @property + def state(self): + """The module's state dictionary + + The module's state dictionary contains any attribute set on the + module including parameters in :meth:`Module.parameters` + + Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is + a reference to the module's state. Updates to it will be reflected in + the original module. + """ + return self + def _extra_repr(self): return "" diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py index b659ec5cf..4a53d4681 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers.py @@ -7,39 +7,14 @@ import mlx.core as mx from mlx.utils import tree_map -class OptimizerState(dict): - """The optimizer state implements a recursively defined - :class:`collections.defaultdict`, namely a missing key in an optimizer - state is an :class:`OptimizerState`. - - .. note:: - :meth:`OptimizerState.get` in contrast to a normal dictionary also sets - the key to the ``default`` value if the ``key`` was not present in the - dictionary. - """ - - def __getitem__(self, key): - if key not in self: - self[key] = OptimizerState() - return super().__getitem__(key) - - def get(self, key, default): - """If ``key`` doesn't exist set its value to ``default`` and then return it.""" - if key not in self: - self[key] = default - return super().__getitem__(key) - - class Optimizer: """The base class for all optimizers. It allows us to implement an optimizer on a per-parameter basis and apply it to a parameter tree. - - Attributes: - state (OptimizerState): It holds the optimizer's state dictionary. """ def __init__(self): - self.state = OptimizerState() + self._initialized = False + self._state = {} def update(self, model: "mlx.nn.Module", gradients: dict): """Apply the gradients to the parameters of the model and update the @@ -52,7 +27,41 @@ class Optimizer: """ model.update(self.apply_gradients(gradients, model)) - def apply_gradients(self, gradients: dict, model: dict): + def init(self, parameters: dict): + """Initialize the optimizer's state + + This function can be used to initialize optimizers which have state + (like momentum in :class:`SGD`). Using this method is optional as the + optimizer will initialize itself if the state is not yet set. However, + there are some cases where explicit initialization is useful in order + to have access to the :attr:`Optimizer.state` before the first call to + :meth:`Optimizer.update`. + + Args: + model (dict): A Python tree of parameters. + + Example: + >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9) + >>> model = nn.Linear(2, 2) + >>> optimizer.init(model.trainable_parameters()) + >>> optimizer.state + {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0], + [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}} + """ + self._state.update(tree_map(lambda x: {}, parameters)) + tree_map(self.init_single, parameters, self._state) + self._initialized = True + + def init_single(self, parameter: mx.array, state: dict): + """To be extended by the children classes to implement each optimizer's + state initialization. + + Args: + parameter (mx.array): A single parameter that will be optimized. + """ + raise NotImplementedError() + + def apply_gradients(self, gradients: dict, parameters: dict): """Apply the gradients to the parameters and return the updated parameters. Can be used to update a model via @@ -61,19 +70,41 @@ class Optimizer: Args: gradients (dict): A Python tree of gradients. - model (dict): A Python tree of parameters. It can be a superset of - the gradients. In that case the returned python tree - will be of the same structure as the gradients. + parameters (dict): A Python tree of parameters. It can be a + superset of the gradients. In that case the returned python + tree will be of the same structure as the gradients. """ - return tree_map(self.apply_single, gradients, model, self.state) + if not self._initialized: + self.init(gradients) + return tree_map(self.apply_single, gradients, parameters, self.state) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): - """To be extended by the children classes to implement each optimizer's - update.""" + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): + """To be extended by derived classes to implement the optimizer's update. + + Args: + gradient (mx.array): The ``parameter`` gradient. + parameter (mx.array): The ``parameter`` to update. + state (dict): The optimizer's state. + """ raise NotImplementedError() + @property + def state(self): + """The optimizer's state dictionary.""" + return self._state + + @state.setter + def state(self, state: dict): + self._state = state + + @property + def learning_rate(self): + return self.state["learning_rate"] + + @learning_rate.setter + def learning_rate(self, learning_rate: mx.array): + self.state["learning_rate"] = mx.array(learning_rate) + class SGD(Optimizer): r"""The stochastic gradient descent optimizer. @@ -113,9 +144,11 @@ class SGD(Optimizer): self.dampening = dampening self.nesterov = nesterov - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the SGD parameter update and stores :math:`v` in the optimizer state.""" @@ -123,24 +156,21 @@ class SGD(Optimizer): gradient += self.weight_decay * parameter if self.momentum <= 0: - return parameter - self.learning_rate * gradient + return parameter - self.learning_rate.astype(gradient.dtype) * gradient + v = self.momentum * state.get("v") if self.dampening > 0: - v = ( - state.get("v", (self.dampening / self.momentum) * gradient) - * self.momentum - ) v += (1 - self.dampening) * gradient else: - v = state.get("v", mx.zeros_like(gradient)) * self.momentum v += gradient if self.nesterov: update = gradient + self.momentum * v else: update = v + state["v"] = v - return parameter - self.learning_rate * update + return parameter - self.learning_rate.astype(gradient.dtype) * update class RMSprop(Optimizer): @@ -177,15 +207,17 @@ class RMSprop(Optimizer): f"RMSprop epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) alpha = self.alpha eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) + v = state["v"] v = alpha * v + (1 - alpha) * mx.square(gradient) state["v"] = v @@ -222,16 +254,17 @@ class Adagrad(Optimizer): f"Adagrad epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adagrad parameter update and stores :math:`v` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) - v = v + mx.square(gradient) + v = state["v"] + mx.square(gradient) state["v"] = v return parameter - lr * gradient / (mx.sqrt(v) + eps) @@ -274,17 +307,20 @@ class AdaDelta(Optimizer): f"AdaDelta epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + state["u"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the AdaDelta parameter update and stores :math:`v` and :math:`u` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) rho = self.rho eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) - u = state.get("u", mx.zeros_like(gradient)) + v = state["v"] + u = state["u"] v = rho * v + (1 - rho) * mx.square(gradient) d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient @@ -329,17 +365,20 @@ class Adam(Optimizer): self.betas = betas self.eps = eps - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adam parameter update and stores :math:`v` and :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas eps = self.eps - m = state.get("m", gradient) - v = state.get("v", mx.square(gradient)) + m = state["m"] + v = state["v"] m = b1 * m + (1 - b1) * gradient v = b2 * v + (1 - b2) * mx.square(gradient) state["m"] = m @@ -385,15 +424,14 @@ class AdamW(Adam): super().__init__(learning_rate=learning_rate, betas=betas, eps=eps) self.weight_decay = weight_decay - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the AdamW parameter update by modifying the parameters passed into Adam. """ + lr = self.learning_rate.astype(gradient.dtype) return super().apply_single( - gradient, parameter * (1 - self.learning_rate * self.weight_decay), state + gradient, parameter * (1 - lr * self.weight_decay), state ) @@ -430,17 +468,20 @@ class Adamax(Adam): f"Epsilon value should be >=0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adamax parameter update and stores :math:`v` and :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas eps = self.eps - m = state.get("m", mx.zeros_like(gradient)) - v = state.get("v", mx.zeros_like(gradient)) + m = state["m"] + v = state["v"] m = b1 * m + (1 - b1) * gradient v = mx.maximum(b2 * v, mx.abs(gradient)) @@ -489,16 +530,18 @@ class Lion(Optimizer): self.betas = betas self.weight_decay = weight_decay - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Lion parameter update and stores :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas weight_decay = self.weight_decay - m = state.get("m", gradient) + m = state["m"] c = b1 * m + (1 - b1) * gradient state["m"] = b2 * m + (1 - b2) * gradient if weight_decay > 0: @@ -552,7 +595,8 @@ class Adafactor(Optimizer): warmup_init: bool = False, ): super().__init__() - self.learning_rate = learning_rate + if learning_rate is not None: + self.learning_rate = learning_rate self.eps = eps self.clip_threshold = clip_threshold self.decay_rate = decay_rate @@ -562,14 +606,29 @@ class Adafactor(Optimizer): self.relative_step = relative_step self.warmup_init = warmup_init + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["step"] = 0 + if parameter.ndim >= 2: + shape = parameter.shape + dtype = parameter.dtype + state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype) + state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype) + else: + state["exp_avg_sq"] = mx.zeros_like(parameter) + + if self.beta_1 is not None: + state["exp_avg"] = mx.zeros_like(parameter) + def _compute_rms(self, inputs): return mx.sqrt(mx.mean(mx.square(inputs))) def _compute_learning_rate(self, step, parameter_rms): - relative_step_size = self.learning_rate if self.relative_step: min_step = 1e-6 * step if self.warmup_init else 1e-2 relative_step_size = min(min_step, 1 / math.sqrt(step)) + else: + relative_step_size = self.learning_rate.astype(parameter_rms) parameter_scale = 1.0 if self.scale_parameter: @@ -585,13 +644,11 @@ class Adafactor(Optimizer): mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adafactor parameter and state update.""" - gradient_shape = gradient.shape - factored = len(gradient_shape) >= 2 - step = state.get("step", 0) + 1 + factored = gradient.ndim >= 2 + + step = state["step"] + 1 state["step"] = step use_first_moment = self.beta_1 is not None @@ -601,15 +658,8 @@ class Adafactor(Optimizer): update = mx.square(gradient) + self.eps[0] if factored: - exp_avg_sq_row = state.get( - "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype) - ) - exp_avg_sq_col = state.get( - "exp_avg_sq_col", - mx.zeros( - gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype - ), - ) + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( (1 - beta_2) * mx.mean(update, axis=-1) ) @@ -621,7 +671,7 @@ class Adafactor(Optimizer): update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) update = update * gradient else: - exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient)) + exp_avg_sq = state["exp_avg_sq"] exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) state["exp_avg_sq"] = exp_avg_sq update = mx.rsqrt(exp_avg_sq) * gradient @@ -632,7 +682,7 @@ class Adafactor(Optimizer): update = learning_rate * update if use_first_moment: - exp_avg = state.get("exp_avg", mx.zeros_like(gradient)) + exp_avg = state["exp_avg"] exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) state["exp_avg"] = exp_avg update = exp_avg diff --git a/python/src/random.cpp b/python/src/random.cpp index e9140e7d9..bbcb7a2c8 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "python/src/utils.h" @@ -13,13 +14,55 @@ using namespace py::literals; using namespace mlx::core; using namespace mlx::core::random; +class PyKeySequence { + public: + explicit PyKeySequence(uint64_t seed) { + state_.append(key(seed)); + } + + void seed(uint64_t seed) { + state_[0] = key(seed); + } + + array next() { + auto out = split(py::cast(state_[0])); + state_[0] = out.first; + return out.second; + } + + py::list state() { + return state_; + } + + void release() { + py::gil_scoped_acquire gil; + state_.release().dec_ref(); + } + + private: + py::list state_; +}; + +PyKeySequence& default_key() { + auto get_current_time_seed = []() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + }; + static PyKeySequence ks(get_current_time_seed()); + return ks; +} + void init_random(py::module_& parent_module) { auto m = parent_module.def_submodule( "random", "mlx.core.random: functionality related to random number generation"); + + m.attr("state") = default_key().state(); m.def( "seed", - &seed, + [](uint64_t seed) { default_key().seed(seed); }, "seed"_a, R"pbdoc( Seed the global PRNG. @@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& high, const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return uniform( to_array(low), to_array(high), @@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) { std::optional type, float loc, float scale, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return normal(shape, type.value_or(float32), loc, scale, key, s); }, - "shape"_a = std::vector{}, "dtype"_a = std::optional{float32}, "loc"_a = 0.0, @@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& high, const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return randint( to_array(low), to_array(high), shape, type.value_or(int32), key, s); }, @@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) { "bernoulli", [](const ScalarOrArray& p_, const std::optional> shape, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); auto p = to_array(p_); if (shape.has_value()) { return bernoulli(p, shape.value(), key, s); @@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& upper_, const std::optional> shape_, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); auto lower = to_array(lower_); auto upper = to_array(upper_); auto t = type.value_or(float32); @@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) { "gumbel", [](const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return gumbel(shape, type.value_or(float32), key, s); }, "shape"_a = std::vector{}, @@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) { int axis, const std::optional> shape, const std::optional num_samples, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); if (shape.has_value() && num_samples.has_value()) { throw std::invalid_argument( "[categorical] At most one of shape or num_samples can be specified."); @@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) { Returns: array: The ``shape``-sized output array with type ``uint32``. )pbdoc"); + // Register static Python object cleanup before the interpreter exits + auto atexit = py::module_::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { default_key().release(); })); } diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 78f867876..77170414a 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -135,6 +135,64 @@ py::object tree_map( }); } +void tree_visit_update( + py::object tree, + std::function visitor) { + std::function recurse; + recurse = [&](py::handle subtree) { + if (py::isinstance(subtree)) { + auto l = py::cast(subtree); + for (int i = 0; i < l.size(); ++i) { + l[i] = recurse(l[i]); + } + return py::cast(l); + } else if (py::isinstance(subtree)) { + for (auto item : subtree) { + recurse(item); + } + return py::cast(subtree); + } else if (py::isinstance(subtree)) { + auto d = py::cast(subtree); + for (auto item : d) { + d[item.first] = recurse(item.second); + } + return py::cast(d); + } else if (py::isinstance(subtree)) { + return visitor(subtree); + } else { + return py::cast(subtree); + } + }; + recurse(tree); +} + +// Fill a pytree (recursive dict or list of dict or list) +// in place with the given arrays +// Non dict or list nodes are ignored +void tree_fill(py::object& tree, const std::vector& values) { + size_t index = 0; + tree_visit_update( + tree, [&](py::handle node) { return py::cast(values[index++]); }); +} + +// Replace all the arrays from the src values with the dst values in the tree +void tree_replace( + py::object& tree, + const std::vector& src, + const std::vector& dst) { + std::unordered_map src_to_dst; + for (int i = 0; i < src.size(); ++i) { + src_to_dst.insert({src[i].id(), dst[i]}); + } + tree_visit_update(tree, [&](py::handle node) { + auto arr = py::cast(node); + if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { + return py::cast(it->second); + } + return py::cast(arr); + }); +} + std::vector tree_flatten(py::object tree, bool strict = true) { std::vector flat_tree; @@ -495,9 +553,15 @@ std::unordered_map& tree_cache() { struct PyCompiledFun { py::function fun; size_t fun_id; + py::object captured_inputs; + py::object captured_outputs; + size_t num_outputs{0}; - PyCompiledFun(const py::function& fun) - : fun(fun), fun_id(reinterpret_cast(fun.ptr())) {} + PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs) + : fun(fun), + fun_id(reinterpret_cast(fun.ptr())), + captured_inputs(inputs), + captured_outputs(outputs) {} PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete; @@ -505,23 +569,61 @@ struct PyCompiledFun { PyCompiledFun(PyCompiledFun&& other) : fun(std::move(other.fun)), fun_id(reinterpret_cast(fun.ptr())) { other.fun_id = 0; + captured_inputs = std::move(other.captured_inputs); + captured_outputs = std::move(other.captured_outputs); + num_outputs = other.num_outputs; }; py::object operator()(const py::args& args) { auto compile_fun = [this, &args](const std::vector& a) { - // Call the python function and flatten the outputs - auto [outputs, py_outputs] = tree_flatten_with_structure( - std::move(this->fun(*tree_unflatten(args, a))), true); + // Put tracers into captured inputs + std::vector flat_in_captures; + std::vector trace_captures; + if (!py::isinstance(captured_inputs)) { + flat_in_captures = tree_flatten(captured_inputs, false); + trace_captures.insert( + trace_captures.end(), a.end() - flat_in_captures.size(), a.end()); + tree_fill(captured_inputs, trace_captures); + } - tree_cache().insert({this->fun_id, py_outputs}); + auto [outputs, py_outputs] = tree_flatten_with_structure( + std::move(fun(*tree_unflatten(args, a))), false); + + tree_cache().insert({fun_id, py_outputs}); + + num_outputs = outputs.size(); + if (!py::isinstance(captured_outputs)) { + auto flat_out_captures = tree_flatten(captured_outputs, false); + outputs.insert( + outputs.end(), + std::make_move_iterator(flat_out_captures.begin()), + std::make_move_iterator(flat_out_captures.end())); + } + + // Replace tracers with originals in captured inputs + if (!py::isinstance(captured_inputs)) { + tree_replace(captured_inputs, trace_captures, flat_in_captures); + } return outputs; }; - // Inputs must be array or tree of arrays - auto inputs = tree_flatten(args, true); + auto inputs = tree_flatten(args, false); + if (!py::isinstance(captured_inputs)) { + auto flat_in_captures = tree_flatten(captured_inputs, false); + inputs.insert( + inputs.end(), + std::make_move_iterator(flat_in_captures.begin()), + std::make_move_iterator(flat_in_captures.end())); + } // Compile and call auto outputs = detail::compile(compile_fun, fun_id)(inputs); + if (!py::isinstance(captured_outputs)) { + std::vector captures( + std::make_move_iterator(outputs.begin() + num_outputs), + std::make_move_iterator(outputs.end())); + tree_fill(captured_outputs, captures); + } // Put the outputs back in the container py::object py_outputs = tree_cache().at(fun_id); @@ -534,6 +636,8 @@ struct PyCompiledFun { tree_cache().erase(fun_id); detail::compile_erase(fun_id); fun.release().dec_ref(); + captured_inputs.release().dec_ref(); + captured_outputs.release().dec_ref(); } }; @@ -601,7 +705,7 @@ void init_transforms(py::module_& m) { m.def( "eval", [](const py::args& args) { - std::vector arrays = tree_flatten(args); + std::vector arrays = tree_flatten(args, false); { py::gil_scoped_release nogil; eval(arrays); @@ -615,8 +719,8 @@ void init_transforms(py::module_& m) { Args: *args (arrays or trees of arrays): Each argument can be a single array or a tree of arrays. If a tree is given the nodes can be a Python - :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be - an :class:`array`. + :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not + arrays are ignored. )pbdoc"); m.def( "jvp", @@ -859,10 +963,14 @@ void init_transforms(py::module_& m) { "file"_a); m.def( "compile", - [](const py::function& fun) { - return py::cpp_function(PyCompiledFun{fun}); + [](const py::function& fun, + const py::object& inputs, + const py::object& outputs) { + return py::cpp_function(PyCompiledFun{fun, inputs, outputs}); }, "fun"_a, + "inputs"_a = std::nullopt, + "outputs"_a = std::nullopt, R"pbdoc( compile(fun: function) -> function @@ -872,6 +980,16 @@ void init_transforms(py::module_& m) { fun (function): A function which takes a variable number of :class:`array` or trees of :class:`array` and returns a variable number of :class:`array` or trees of :class:`array`. + inputs (list or dict, optional): These inputs will be captured during + the function compilation along with the inputs to ``fun``. The ``inputs`` + can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested + lists, dictionaries, or arrays. Leaf nodes that are not + :obj:`array` are ignored. Default: ``None`` + outputs (list or dict, optional): These outputs will be captured and + updated in a compiled function. The ``outputs`` can be a + :obj:`list` or a :obj:`dict` containing arbitrarily nested lists, + dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. + Default: ``None`` Returns: function: A compiled function which has the same input arguments diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 56dff8b3d..2e0bb1d7f 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -2,6 +2,7 @@ import io import unittest +from functools import partial import mlx.core as mx import mlx_tests @@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase): cdfdx = mx.grad(outer)(x) self.assertTrue(mx.allclose(dfdx, cdfdx)) + def test_compile_capture(self): + # Test update captured state outside compiled function + state = {"y": mx.array(2)} + + @partial(mx.compile, inputs=state) + def test_state(x): + x = x + state["y"] + return x + + test_state(mx.array(1)) + # Check the state is unchanged + self.assertEqual(state["y"], 2) + + # Check the udpated state is used + state["y"] = mx.array(3) + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Capture list + state = [mx.array(2)] + + @partial(mx.compile, inputs=state) + def test_state(x): + x = x + state[0] + return x + + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 3) + state[0] = mx.array(3) + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Capture tuple of list + state = ([mx.array(2)],) + + @partial(mx.compile, inputs=state) + def test_state(x): + x = x + state[0][0] + return x + + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 3) + state[0][0] = mx.array(3) + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Test state updated inside compiled function + state = {} + + @partial(mx.compile, outputs=state) + def test_state(x): + state["y"] = x + 3 + return mx.abs(x) + + test_state(mx.array(-1)) + self.assertEqual(state["y"].item(), 2) + + # Test state changed inside compiled function + # triggers recompile + state = {} + + @partial(mx.compile, inputs=state, outputs=state) + def test_state(x): + y = state.get("y", mx.array(0)) + state["y"] = x + y + return x + 2 * y + + test_state(mx.array(1)) + self.assertEqual(state["y"].item(), 1) + test_state(mx.array(1)) + self.assertEqual(state["y"].item(), 2) + + def test_compile_rng(self): + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) + def fun(): + return mx.random.uniform(shape=(10, 10)) + + self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 6619afa67..dc986a19a 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase): y = dfun_dx(mx.array(1.0)) self.assertEqual(y.item(), 6.0) + def test_eval_mixed(self): + x = mx.array(1) + 1 + 1 + y = 0 + z = "hello" + state = [x, y, z] + mx.eval(state) + self.assertEqual(x.item(), 3) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index d7b84bbf6..201665f7f 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase): ] ) + def test_module_state(self): + m = nn.Linear(10, 1) + m.state["hello"] = "world" + self.assertEqual(m.state["hello"], "world") + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index 59046184f..f894a7510 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -2,47 +2,209 @@ import inspect import unittest +from functools import partial import mlx.core as mx +import mlx.nn as nn import mlx.optimizers as opt import mlx.utils import mlx_tests +from mlx.utils import tree_flatten, tree_map def get_all_optimizers(): classes = dict() for name, obj in inspect.getmembers(opt): if inspect.isclass(obj): - if obj.__name__ not in ["OptimizerState", "Optimizer"]: + if obj.__name__ not in ["Optimizer"]: classes[name] = obj return classes +def tree_equal(fn, *args): + return all(v for _, v in tree_flatten(tree_map(fn, *args))) + + optimizers_dict = get_all_optimizers() class TestOptimizers(mlx_tests.MLXTestCase): + def test_optimizer_state(self): + optim = opt.SGD(0.1) + optim.state["hello"] = "world" + self.assertEqual(optim.state["hello"], "world") + + optim.state = {0: 1} + self.assertEqual(optim.state, {0: 1}) + def test_optimizers(self): params = { "first": [mx.zeros((10,)), mx.zeros((1,))], "second": mx.zeros((1,)), } - grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) + grads = tree_map(lambda x: mx.ones_like(x), params) for optim_class in optimizers_dict.values(): optim = optim_class(0.1) update = optim.apply_gradients(grads, params) mx.eval(update) - equal_shape = mlx.utils.tree_map( - lambda x, y: x.shape == y.shape, params, update - ) + equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update) all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) self.assertTrue(all_equal) + def test_types_conserved(self): + params = {"w": mx.ones((5, 5), mx.float16)} + grads = tree_map(lambda x: mx.ones_like(x), params) + for optim_class in optimizers_dict.values(): + optim = optim_class(0.1) + update = optim.apply_gradients(grads, params) + self.assertEqual(update["w"].dtype, mx.float16) + + def test_sgd(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + # Implicit init + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + optim.apply_gradients(grads, params) + self.assertTrue( + tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state) + ) + + def test_rmsprop(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.RMSprop(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + # Implicit init + alpha = 0.99 + optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha) + optim.apply_gradients(grads, params) + self.assertTrue( + tree_equal( + lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state + ) + ) + + def test_adagrad(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.Adagrad(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_adadelta(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.AdaDelta(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_adam(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]: + optim = optimizer(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_lion(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.Lion(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + def test_adafactor(self): x = mx.zeros((5, 5)) grad = mx.ones_like(x) optimizer = opt.Adafactor() + optimizer.init(x) for _ in range(2): xp = optimizer.apply_single(grad, x, optimizer.state) self.assertEqual(xp.dtype, x.dtype) @@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase): x = mx.zeros((5, 5), mx.float16) grad = mx.ones_like(x) optimizer = opt.Adafactor() + optimizer.init(x) for _ in range(2): xp = optimizer.apply_single(grad, x, optimizer.state) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) self.assertEqual(optimizer.state["step"], 2) + def test_compiled_optimizer(self): + model = nn.Linear(10, 10) + x = mx.random.uniform(shape=(2, 10)) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + + orig_params = model.parameters() + + def loss(model, x): + return model(x).sum() + + # Uncompiled version + def step(x): + _, grad = nn.value_and_grad(model, loss)(model, x) + optim.update(model, grad) + + step(x) + uncompiled_params = model.parameters() + + # Pure version + def loss(params, x): + model.update(params) + return model(x).sum() + + model.update(orig_params) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + + @mx.compile + def step(params, opt_state, x): + grad = mx.grad(loss)(params, x) + optim.state = opt_state + params = optim.apply_gradients(grad, params) + return params, optim.state + + optim.init(model.parameters()) + pure_params, _ = step(model.parameters(), optim.state, x) + self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"])) + self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"])) + + # Impure version + def loss(model, x): + return model(x).sum() + + model.update(orig_params) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + state = [model.state, optim.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x): + _, grad = nn.value_and_grad(model, loss)(model, x) + optim.update(model, grad) + + step(x) + impure_params = model.parameters() + self.assertTrue( + mx.allclose(impure_params["weight"], uncompiled_params["weight"]) + ) + self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"])) + + def test_update_lr_compiled(self): + params = {"w": mx.ones((5, 5))} + grads = tree_map(lambda x: mx.ones_like(x), params) + optim = opt.SGD(-1.0) + + @partial(mx.compile, inputs=optim.state) + def update(grads): + return optim.apply_gradients(grads, params) + + result = update(grads) + self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0))) + optim.learning_rate = -2.0 + result = update(grads) + self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) + if __name__ == "__main__": unittest.main() From 7dccd42133d9617c35ba68e802e22fc1d07e512a Mon Sep 17 00:00:00 2001 From: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> Date: Thu, 8 Feb 2024 18:01:59 +0100 Subject: [PATCH 07/42] updated calls to use loc &scale (#643) --- python/mlx/nn/init.py | 6 +++--- python/mlx/nn/layers/embedding.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index 5afc6170e..6596ba741 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -60,7 +60,7 @@ def normal( """ def initializer(a: mx.array) -> mx.array: - return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean + return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) return initializer @@ -184,7 +184,7 @@ def glorot_normal( def initializer(a: mx.array, gain: float = 1.0) -> mx.array: fan_in, fan_out = _calculate_fan_in_fan_out(a) std = gain * math.sqrt(2.0 / (fan_in + fan_out)) - return mx.random.normal(shape=a.shape, dtype=dtype) * std + return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) return initializer @@ -285,7 +285,7 @@ def he_normal( raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out") std = gain / math.sqrt(fan) - return mx.random.normal(shape=a.shape, dtype=dtype) * std + return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) return initializer diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py index c62b1206f..18482eddc 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -21,7 +21,7 @@ class Embedding(Module): def __init__(self, num_embeddings: int, dims: int): super().__init__() scale = math.sqrt(1 / dims) - self.weight = mx.random.normal((num_embeddings, dims)) * scale + self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) def _extra_repr(self): return f"{self.weight.shape[0]}, {self.weight.shape[1]}" From 5c03efaf296b30b361e5d47afb0fe3af58a93961 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 8 Feb 2024 11:21:50 -0800 Subject: [PATCH 08/42] Compile docs (#653) * compile docs * docs nits + comments --- docs/src/index.rst | 1 + docs/src/python/transforms.rst | 3 + docs/src/usage/compile.rst | 430 +++++++++++++++++++++++++ docs/src/usage/function_transforms.rst | 17 +- python/src/transforms.cpp | 2 +- 5 files changed, 445 insertions(+), 8 deletions(-) create mode 100644 docs/src/usage/compile.rst diff --git a/docs/src/index.rst b/docs/src/index.rst index 4f4411758..50dfe9083 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -41,6 +41,7 @@ are the CPU and GPU. usage/indexing usage/saving_and_loading usage/function_transforms + usage/compile usage/numpy usage/using_streams diff --git a/docs/src/python/transforms.rst b/docs/src/python/transforms.rst index cc8d681d5..ad9ba579b 100644 --- a/docs/src/python/transforms.rst +++ b/docs/src/python/transforms.rst @@ -9,6 +9,9 @@ Transforms :toctree: _autosummary eval + compile + disable_compile + enable_compile grad value_and_grad jvp diff --git a/docs/src/usage/compile.rst b/docs/src/usage/compile.rst new file mode 100644 index 000000000..97d5503a3 --- /dev/null +++ b/docs/src/usage/compile.rst @@ -0,0 +1,430 @@ +.. _compile: + +Compilation +=========== + +.. currentmodule:: mlx.core + +MLX has a :func:`compile` function transformation which compiles computation +graphs. Function compilation results in smaller graphs by merging common work +and fusing certain operations. In many cases this can lead to big improvements +in run-time and memory use. + +Getting started with :func:`compile` is simple, but there are some edge cases +that are good to be aware of for more complex graphs and advanced usage. + +Basics of Compile +----------------- + +Let's start with a simple example: + +.. code-block:: python + + def fun(x, y): + return mx.exp(-x) + y + + x = mx.array(1.0) + y = mx.array(2.0) + + # Regular call, no compilation + # Prints: array(2.36788, dtype=float32) + print(fun(x, y)) + + # Compile the function + compiled_fun = mx.compile(fun) + + # Prints: array(2.36788, dtype=float32) + print(compiled_fun(x, y)) + +The output of both the regular function and the compiled function is the same +up to numerical precision. + +The first time you call a compiled function, MLX will build the compute +graph, optimize it, and generate and compile code. This can be relatively +slow. However, MLX will cache compiled functions, so calling a compiled +function multiple times will not initiate a new compilation. This means you +should typically compile functions that you plan to use more than once. + +.. code-block:: python + + def fun(x, y): + return mx.exp(-x) + y + + x = mx.array(1.0) + y = mx.array(2.0) + + compiled_fun = mx.compile(fun) + + # Compiled here + compiled_fun(x, y) + + # Not compiled again + compiled_fun(x, y) + + # Not compiled again + mx.compile(fun)(x, y) + +There are some important cases to be aware of that can cause a function to +be recompiled: + +* Changing the shape or number of dimensions +* Changing the type of any of the inputs +* Changing the number of inputs to the function + +In certain cases only some of the compilation stack will be rerun (for +example when changing the shapes) and in other cases the full compilation +stack will be rerun (for example when changing the types). In general you +should avoid compiling functions too frequently. + +Another idiom to watch out for is compiling functions which get created and +destroyed frequently. This can happen, for example, when compiling an anonymous +function in a loop: + +.. code-block:: python + + a = mx.array(1.0) + # Don't do this, compiles lambda at each iteration + for _ in range(5): + mx.compile(lambda x: mx.exp(mx.abs(x)))(a) + +Example Speedup +--------------- + +The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with +Transformer-based models. The implementation involves several unary and binary +element-wise operations: + +.. code-block:: python + + def gelu(x): + return x * (1 + mx.erf(x / math.sqrt(2))) / 2 + +If you use this function with small arrays, it will be overhead bound. If you +use it with large arrays it will be memory bandwidth bound. However, all of +the operations in the ``gelu`` are fusible into a single kernel with +:func:`compile`. This can speedup both cases considerably. + +Let's compare the runtime of the regular function versus the compiled +function. We'll use the following timing helper which does a warm up and +handles synchronization: + +.. code-block:: python + + import time + + def timeit(fun, x): + # warm up + for _ in range(10): + mx.eval(fun(x)) + + tic = time.perf_counter() + for _ in range(100): + mx.eval(fun(x)) + toc = time.perf_counter() + tpi = 1e3 * (toc - tic) / 100 + print(f"Time per iteration {tpi:.3f} (ms)") + + +Now make an array, and benchmark both functions: + +.. code-block:: python + + x = mx.random.uniform(shape=(32, 1000, 4096)) + timeit(nn.gelu, x) + timeit(mx.compile(nn.gelu), x) + +On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is +five times faster. + +.. note:: + + As of the latest MLX, CPU functions are not fully compiled. Compiling CPU + functions can still be helpful, but won't typically result in as large a + speedup as compiling operations that run on the GPU. + + +Debugging +--------- + +When a compiled function is first called, it is traced with placeholder +inputs. This means you can't evaluate arrays (for example to print their +contents) inside compiled functions. + +.. code-block:: python + + @mx.compile + def fun(x): + z = -x + print(z) # Crash + return mx.exp(z) + + fun(mx.array(5.0)) + +For debugging, inspecting arrays can be helpful. One way to do that is to +globally disable compilation using the :func:`disable_compile` function or +``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though +``fun`` is compiled: + +.. code-block:: python + + @mx.compile + def fun(x): + z = -x + print(z) # Okay + return mx.exp(z) + + mx.disable_compile() + fun(mx.array(5.0)) + + +Pure Functions +-------------- + +Compiled functions are intended to be *pure*; that is they should not have side +effects. For example: + +.. code-block:: python + + state = [] + + @mx.compile + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z) + + fun(mx.array(1.0), mx.array(2.0)) + # Crash! + print(state) + +After the first call of ``fun``, the ``state`` list will hold a placeholder +array. The placeholder does not have any data; it is only used to build the +computation graph. Printing such an array results in a crash. + +You have two options to deal with this. The first option is to simply return +``state`` as an output: + +.. code-block:: python + + state = [] + + @mx.compile + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z), state + + _, state = fun(mx.array(1.0), mx.array(2.0)) + # Prints [array(3, dtype=float32)] + print(state) + +In some cases returning updated state can be pretty inconvenient. Hence, +:func:`compile` has a parameter to capture implicit outputs: + +.. code-block:: python + + from functools import partial + + state = [] + + # Tell compile to capture state as an output + @partial(mx.compile, outputs=state) + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z), state + + fun(mx.array(1.0), mx.array(2.0)) + # Prints [array(3, dtype=float32)] + print(state) + +This is particularly useful for compiling a function which includes an update +to a container of arrays, as is commonly done when training the parameters of a +:class:`mlx.nn.Module`. + +Compiled functions will also treat any inputs not in the parameter list as +constants. For example: + +.. code-block:: python + + state = [mx.array(1.0)] + + @mx.compile + def fun(x): + return x + state[0] + + # Prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + + # Update state + state[0] = mx.array(5.0) + + # Still prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + +In order to have the change of state reflected in the outputs of ``fun`` you +again have two options. The first option is to simply pass ``state`` as input +to the function. In some cases this can be pretty inconvenient. Hence, +:func:`compile` also has a parameter to capture implicit inputs: + +.. code-block:: python + + from functools import partial + state = [mx.array(1.0)] + + # Tell compile to capture state as an input + @partial(mx.compile, inputs=state) + def fun(x): + return x + state[0] + + # Prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + + # Update state + state[0] = mx.array(5.0) + + # Prints array(6, dtype=float32) + print(fun(mx.array(1.0))) + + +Compiling Training Graphs +------------------------- + +This section will step through how to use :func:`compile` with a simple example +of a common setup: training a model with :obj:`mlx.nn.Module` using an +:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the +full forward, backward, and update with :func:`compile`. + +To start, here is the simple example without any compilation: + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + import mlx.optimizers as optim + + # 4 examples with 10 features each + x = mx.random.uniform(shape=(4, 10)) + + # 0, 1 targets + y = mx.array([0, 1, 0, 1]) + + # Simple linear model + model = nn.Linear(10, 1) + + # SGD with momentum + optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) + + def loss_fn(model, x, y): + logits = model(x).squeeze() + return nn.losses.binary_cross_entropy(logits, y) + + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + + # Perform 10 steps of gradient descent + for it in range(10): + loss, grads = loss_and_grad_fn(model, x, y) + optimizer.update(model, grads) + mx.eval(model.parameters(), optimizer.state) + +To compile the update we can put it all in a function and compile it with the +appropriate input and output captures. Here's the same example but compiled: + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + import mlx.optimizers as optim + from functools import partial + + # 4 examples with 10 features each + x = mx.random.uniform(shape=(4, 10)) + + # 0, 1 targets + y = mx.array([0, 1, 0, 1]) + + # Simple linear model + model = nn.Linear(10, 1) + + # SGD with momentum + optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) + + def loss_fn(model, x, y): + logits = model(x).squeeze() + return nn.losses.binary_cross_entropy(logits, y) + + # The state that will be captured as input and output + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x, y): + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, x, y) + optimizer.update(model, grads) + return loss + + # Perform 10 steps of gradient descent + for it in range(10): + loss = step(x, y) + # Evaluate the model and optimizer state + mx.eval(state) + print(loss) + + +.. note:: + + If you are using a module which performs random sampling such as + :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the + ``state`` captured by :func:`compile`, i.e. ``state = [model.state, + optimizer.state, mx.random.state]``. + + +.. note:: + + For more examples of compiling full training graphs checkout the `MLX + Examples `_ GitHub repo. + +Transformations with Compile +---------------------------- + +In MLX function transformations are composable. You can apply any function +transformation to the output of any other function transformation. For more on +this, see the documentation on :ref:`function transforms +`. + +Compiling transformed functions works just as expected: + +.. code-block:: python + + grad_fn = mx.grad(mx.exp) + + compiled_grad_fn = mx.compile(grad_fn) + + # Prints: array(2.71828, dtype=float32) + print(grad_fn(mx.array(1.0))) + + # Also prints: array(2.71828, dtype=float32) + print(compiled_grad_fn(mx.array(1.0))) + +.. note:: + + In order to compile as much as possible, a transformation of a compiled + function will not by default be compiled. To compile the transformed + function simply pass it through :func:`compile`. + +You can also compile functions which themselves call compiled functions. A +good practice is to compile the outer most function to give :func:`compile` +the most opportunity to optimize the computation graph: + +.. code-block:: python + + @mx.compile + def inner(x): + return mx.exp(-mx.abs(x)) + + def outer(x): + inner(inner(x)) + + # Compiling the outer function is good to do as it will likely + # be faster even though the inner functions are compiled + fun = mx.compile(outer) diff --git a/docs/src/usage/function_transforms.rst b/docs/src/usage/function_transforms.rst index 72a313f97..02c5dec48 100644 --- a/docs/src/usage/function_transforms.rst +++ b/docs/src/usage/function_transforms.rst @@ -5,9 +5,12 @@ Function Transforms .. currentmodule:: mlx.core -MLX uses composable function transformations for automatic differentiation and -vectorization. The key idea behind composable function transformations is that -every transformation returns a function which can be further transformed. +MLX uses composable function transformations for automatic differentiation, +vectorization, and compute graph optimizations. To see the complete list of +function transformations check-out the :ref:`API documentation `. + +The key idea behind composable function transformations is that every +transformation returns a function which can be further transformed. Here is a simple example: @@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep getting higher order derivatives. Any of the MLX function transformations can be composed in any order to any -depth. To see the complete list of function transformations check-out the -:ref:`API documentation `. See the following sections for more -information on :ref:`automatic differentiaion ` and -:ref:`automatic vectorization `. +depth. See the following sections for more information on :ref:`automatic +differentiaion ` and :ref:`automatic vectorization `. +For more information on :func:`compile` see the :ref:`compile documentation `. + Automatic Differentiation ------------------------- diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 77170414a..f081fdedd 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1008,7 +1008,7 @@ void init_transforms(py::module_& m) { "enable_compile", &enable_compile, R"pbdoc( - enable_compiler() -> None + enable_compile() -> None Globally enable compilation. This will override the environment variable ``MLX_DISABLE_COMPILE`` if set. From 221f8d3fc2b0cebfee7fb8aa157c67e3947012a0 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 8 Feb 2024 11:27:12 -0800 Subject: [PATCH 09/42] Bump the version to 0.2 (#656) --- CMakeLists.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d6bef9d2..b008e10e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) - set(MLX_VERSION 0.1.0) + set(MLX_VERSION 0.2.0) endif() # --------------------- Processor tests ------------------------- diff --git a/setup.py b/setup.py index 44cccf62d..cce1f8537 100644 --- a/setup.py +++ b/setup.py @@ -152,7 +152,7 @@ if __name__ == "__main__": setup( name="mlx", - version=get_version("0.1.0"), + version=get_version("0.2.0"), author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", From b57bd0488daf195b1490db0d5ebeb833b7940e29 Mon Sep 17 00:00:00 2001 From: Diogo Date: Thu, 8 Feb 2024 22:33:15 -0500 Subject: [PATCH 10/42] Metadata support for safetensors (#639) * metadata support for safetensors * aliases making it alittle more readable * addressing comments * python binding tests --- mlx/io.h | 29 +++++++++++++---------- mlx/io/gguf.cpp | 15 +++++------- mlx/io/safetensor.cpp | 29 +++++++++++++---------- python/src/load.cpp | 49 ++++++++++++++++++++++++++------------- python/src/load.h | 18 +++++++------- python/src/ops.cpp | 4 +++- python/tests/test_load.py | 9 +++++++ tests/load_tests.cpp | 24 ++++++++++++------- 8 files changed, 108 insertions(+), 69 deletions(-) diff --git a/mlx/io.h b/mlx/io.h index c58e1959e..59866ea27 100644 --- a/mlx/io.h +++ b/mlx/io.h @@ -10,6 +10,14 @@ #include "mlx/stream.h" namespace mlx::core { +using GGUFMetaData = + std::variant>; +using GGUFLoad = std::pair< + std::unordered_map, + std::unordered_map>; +using SafetensorsLoad = std::pair< + std::unordered_map, + std::unordered_map>; /** Save array to out stream in .npy format */ void save(std::shared_ptr out_stream, array a); @@ -24,32 +32,29 @@ array load(std::shared_ptr in_stream, StreamOrDevice s = {}); array load(const std::string& file, StreamOrDevice s = {}); /** Load array map from .safetensors file format */ -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( std::shared_ptr in_stream, StreamOrDevice s = {}); -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( const std::string& file, StreamOrDevice s = {}); void save_safetensors( std::shared_ptr in_stream, - std::unordered_map); + std::unordered_map, + std::unordered_map metadata = {}); void save_safetensors( const std::string& file, - std::unordered_map); - -using MetaData = - std::variant>; + std::unordered_map, + std::unordered_map metadata = {}); /** Load array map and metadata from .gguf file format */ -std::pair< - std::unordered_map, - std::unordered_map> -load_gguf(const std::string& file, StreamOrDevice s = {}); + +GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); void save_gguf( std::string file, std::unordered_map array_map, - std::unordered_map meta_data = {}); + std::unordered_map meta_data = {}); } // namespace mlx::core diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index f4047d1a0..9e7953d6e 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -82,7 +82,7 @@ void set_mx_value_from_gguf( gguf_ctx* ctx, uint32_t type, gguf_value* val, - MetaData& value) { + GGUFMetaData& value) { switch (type) { case GGUF_VALUE_TYPE_UINT8: value = array(val->uint8, uint8); @@ -191,12 +191,12 @@ void set_mx_value_from_gguf( } } -std::unordered_map load_metadata(gguf_ctx* ctx) { - std::unordered_map metadata; +std::unordered_map load_metadata(gguf_ctx* ctx) { + std::unordered_map metadata; gguf_key key; while (gguf_get_key(ctx, &key)) { std::string key_name = std::string(key.name, key.namelen); - auto& val = metadata.insert({key_name, MetaData{}}).first->second; + auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second; set_mx_value_from_gguf(ctx, key.type, key.val, val); } return metadata; @@ -230,10 +230,7 @@ std::unordered_map load_arrays(gguf_ctx* ctx) { return array_map; } -std::pair< - std::unordered_map, - std::unordered_map> -load_gguf(const std::string& file, StreamOrDevice s) { +GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { gguf_ctx* ctx = gguf_open(file.c_str()); if (!ctx) { throw std::runtime_error("[load_gguf] gguf_init failed"); @@ -280,7 +277,7 @@ void append_kv_array( void save_gguf( std::string file, std::unordered_map array_map, - std::unordered_map metadata /* = {} */) { + std::unordered_map metadata /* = {} */) { // Add .gguf to file name if it is not there if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") { file += ".gguf"; diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 7e7868d49..1dd59f444 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) { } /** Load array from reader in safetensor format */ -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( std::shared_ptr in_stream, StreamOrDevice s) { //////////////////////////////////////////////////////// @@ -121,9 +121,12 @@ std::unordered_map load_safetensors( size_t offset = jsonHeaderLength + 8; // Load the arrays using metadata std::unordered_map res; + std::unordered_map metadata_map; for (const auto& item : metadata.items()) { if (item.key() == "__metadata__") { - // ignore metadata for now + for (const auto& meta_item : item.value().items()) { + metadata_map.insert({meta_item.key(), meta_item.value()}); + } continue; } std::string dtype = item.value().at("dtype"); @@ -138,19 +141,18 @@ std::unordered_map load_safetensors( std::vector{}); res.insert({item.key(), loaded_array}); } - return res; + return {res, metadata_map}; } -std::unordered_map load_safetensors( - const std::string& file, - StreamOrDevice s) { +SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { return load_safetensors(std::make_shared(file), s); } /** Save array to out stream in .npy format */ void save_safetensors( std::shared_ptr out_stream, - std::unordered_map a) { + std::unordered_map a, + std::unordered_map metadata /* = {} */) { //////////////////////////////////////////////////////// // Check file if (!out_stream->good() || !out_stream->is_open()) { @@ -161,9 +163,11 @@ void save_safetensors( //////////////////////////////////////////////////////// // Check array map json parent; - parent["__metadata__"] = json::object({ - {"format", "mlx"}, - }); + json _metadata; + for (auto& [key, value] : metadata) { + _metadata[key] = value; + } + parent["__metadata__"] = _metadata; size_t offset = 0; for (auto& [key, arr] : a) { arr.eval(); @@ -204,7 +208,8 @@ void save_safetensors( void save_safetensors( const std::string& file_, - std::unordered_map a) { + std::unordered_map a, + std::unordered_map metadata /* = {} */) { // Open and check file std::string file = file_; @@ -214,7 +219,7 @@ void save_safetensors( file += ".safetensors"; // Serialize array - save_safetensors(std::make_shared(file), a); + save_safetensors(std::make_shared(file), a, metadata); } } // namespace mlx::core diff --git a/python/src/load.cpp b/python/src/load.cpp index 18e89c7fb..9b6a6861e 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -160,31 +160,29 @@ class PyFileReader : public io::Reader { py::object tell_func_; }; -std::unordered_map mlx_load_safetensor_helper( - py::object file, - StreamOrDevice s) { +std::pair< + std::unordered_map, + std::unordered_map> +mlx_load_safetensor_helper(py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .safetensors file path string return load_safetensors(py::cast(file), s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately - auto arr = load_safetensors(std::make_shared(file), s); + auto res = load_safetensors(std::make_shared(file), s); { py::gil_scoped_release gil; - for (auto& [key, arr] : arr) { + for (auto& [key, arr] : std::get<0>(res)) { arr.eval(); } } - return arr; + return res; } throw std::invalid_argument( "[load_safetensors] Input must be a file-like object, or string"); } -std::pair< - std::unordered_map, - std::unordered_map> -mlx_load_gguf_helper(py::object file, StreamOrDevice s) { +GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .gguf file path string return load_gguf(py::cast(file), s); } @@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper( format.emplace(fname.substr(ext + 1)); } - if (return_metadata && format.value() != "gguf") { + if (return_metadata && (format.value() == "npy" || format.value() == "npz")) { throw std::invalid_argument( "[load] metadata not supported for format " + format.value()); } if (format.value() == "safetensors") { - return mlx_load_safetensor_helper(file, s); + auto [dict, metadata] = mlx_load_safetensor_helper(file, s); + if (return_metadata) { + return std::make_pair(dict, metadata); + } + return dict; } else if (format.value() == "npz") { return mlx_load_npz_helper(file, s); } else if (format.value() == "npy") { @@ -444,18 +446,33 @@ void mlx_savez_helper( return; } -void mlx_save_safetensor_helper(py::object file, py::dict d) { +void mlx_save_safetensor_helper( + py::object file, + py::dict d, + std::optional m) { + std::unordered_map metadata_map; + if (m) { + try { + metadata_map = + m.value().cast>(); + } catch (const py::cast_error& e) { + throw std::invalid_argument( + "[save_safetensors] Metadata must be a dictionary with string keys and values"); + } + } else { + metadata_map = std::unordered_map(); + } auto arrays_map = d.cast>(); if (py::isinstance(file)) { { py::gil_scoped_release nogil; - save_safetensors(py::cast(file), arrays_map); + save_safetensors(py::cast(file), arrays_map, metadata_map); } } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { py::gil_scoped_release nogil; - save_safetensors(writer, arrays_map); + save_safetensors(writer, arrays_map, metadata_map); } } else { throw std::invalid_argument( @@ -471,7 +488,7 @@ void mlx_save_gguf_helper( if (py::isinstance(file)) { if (m) { auto metadata_map = - m.value().cast>(); + m.value().cast>(); { py::gil_scoped_release nogil; save_gguf(py::cast(file), arrays_map, metadata_map); diff --git a/python/src/load.h b/python/src/load.h index dbe0f9cd6..21f0cff32 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -15,19 +15,17 @@ using namespace mlx::core; using LoadOutputTypes = std::variant< array, std::unordered_map, - std::pair< - std::unordered_map, - std::unordered_map>>; + SafetensorsLoad, + GGUFLoad>; -std::unordered_map mlx_load_safetensor_helper( +SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s); +void mlx_save_safetensor_helper( py::object file, - StreamOrDevice s); -void mlx_save_safetensor_helper(py::object file, py::dict d); + py::dict d, + std::optional m); + +GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s); -std::pair< - std::unordered_map, - std::unordered_map> -mlx_load_gguf_helper(py::object file, StreamOrDevice s); void mlx_save_gguf_helper( py::object file, py::dict d, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 02a401543..8e08e6ca9 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) { &mlx_save_safetensor_helper, "file"_a, "arrays"_a, + "metadata"_a = none, R"pbdoc( - save_safetensors(file: str, arrays: Dict[str, array]) + save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None) Save array(s) to a binary file in ``.safetensors`` format. @@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) { Args: file (file, str): File in which the array is saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved. + metadata (dict(str, str), optional): The dictionary of metadata to be saved. )pbdoc"); m.def( "save_gguf", diff --git a/python/tests/test_load.py b/python/tests/test_load.py index a37ba83a9..ab2645bcf 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase): def test_save_and_load_safetensors(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir) + with self.assertRaises(Exception): + mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0}) + + mx.save_safetensors( + "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"} + ) + res = mx.load("test.safetensors", return_metadata=True) + self.assertEqual(len(res), 2) + self.assertEqual(res[1], {"testing": "test", "format": "mlx"}) for dt in self.dtypes + ["bfloat16"]: with self.subTest(dtype=dt): diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 51d1659f3..3a7556b57 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") { auto map = std::unordered_map(); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test2", ones({2, 2})}); - save_safetensors(file_path, map); - auto dict = load_safetensors(file_path); + auto _metadata = std::unordered_map(); + _metadata.insert({"test", "test"}); + _metadata.insert({"test2", "test2"}); + save_safetensors(file_path, map, _metadata); + auto [dict, metadata] = load_safetensors(file_path); + + CHECK_EQ(metadata, _metadata); + CHECK_EQ(dict.size(), 2); CHECK_EQ(dict.count("test"), 1); CHECK_EQ(dict.count("test2"), 1); @@ -55,7 +61,7 @@ TEST_CASE("test gguf") { } // Test saving and loading string metadata - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_str", "my string"}); save_gguf(file_path, original_weights, original_metadata); @@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") { // Scalar array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array(1.0)}); save_gguf(file_path, original_weights, original_metadata); @@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") { // 1D Array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; auto arr = array({1.0, 2.0}); original_metadata.insert({"test_arr", arr}); save_gguf(file_path, original_weights, original_metadata); @@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") { // > 1D array throws { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array({1.0}, {1, 1})}); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); } // empty array throws { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array({})}); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); } // vector of string { - std::unordered_map original_metadata; + std::unordered_map original_metadata; std::vector data = {"data1", "data2", "data1234"}; original_metadata.insert({"meta", data}); save_gguf(file_path, original_weights, original_metadata); @@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") { // vector of string, string, scalar, and array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; std::vector data = {"data1", "data2", "data1234"}; original_metadata.insert({"meta1", data}); original_metadata.insert({"meta2", array(2.5)}); From b6704851853548dd6d0b57188f11bb9f067062c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Abdussamet=20T=C3=BCrker?= <53705368+abdussamettrkr@users.noreply.github.com> Date: Sat, 10 Feb 2024 03:49:14 +0300 Subject: [PATCH 11/42] Remainder negative numerator bug fixed (#641) Co-authored-by: Angelos Katharopoulos --- mlx/backend/accelerate/primitives.cpp | 40 +--------------------- mlx/backend/common/binary.cpp | 24 +++++++++++-- mlx/backend/metal/kernels/binary.h | 30 ++++++++++------ mlx/backend/metal/kernels/binary_two.metal | 27 ++++++++++++--- mlx/backend/metal/kernels/complex.h | 6 ++++ mlx/types/complex.h | 10 ++++++ python/tests/test_ops.py | 14 ++++++++ 7 files changed, 95 insertions(+), 56 deletions(-) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 6cd851111..499cc0ce4 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -62,6 +62,7 @@ DEFAULT(Partition) DEFAULT_MULTI(QRF) DEFAULT(RandomBits) DEFAULT(Reshape) +DEFAULT(Remainder) DEFAULT(Round) DEFAULT(Scatter) DEFAULT(Sigmoid) @@ -292,45 +293,6 @@ void Divide::eval_cpu(const std::vector& inputs, array& out) { } } -// TODO: Avoid code duplication with the common backend. -struct RemainderFn { - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return std::fmod(numerator, denominator); - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } -}; - -void Remainder::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - - if (a.dtype() == float32) { - binary( - a, - b, - out, - RemainderFn{}, - UseDefaultBinaryOp(), - UseDefaultBinaryOp(), - [](const auto* a, const auto* b, auto* o, auto n) { - int num_el = n; - vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el); - }); - } else { - binary(a, b, out, RemainderFn{}); - } -} - void Exp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index a51d22d0f..855e8467b 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -140,16 +140,34 @@ void Divide::eval(const std::vector& inputs, array& out) { struct RemainderFn { template - std::enable_if_t, T> operator()( + std::enable_if_t & !std::is_signed_v, T> operator()( T numerator, T denominator) { - return std::fmod(numerator, denominator); + return numerator % denominator; } template - std::enable_if_t, T> operator()( + std::enable_if_t & std::is_signed_v, T> operator()( T numerator, T denominator) { + auto r = numerator % denominator; + if (r != 0 && (r < 0 != denominator < 0)) + r += denominator; + return r; + } + + template + std::enable_if_t, T> operator()( + T numerator, + T denominator) { + auto r = std::fmod(numerator, denominator); + if (r != 0 && (r < 0 != denominator < 0)) { + r += denominator; + } + return r; + } + + complex64_t operator()(complex64_t numerator, complex64_t denominator) { return numerator % denominator; } }; diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 8adb84c58..006f2ff0e 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -24,20 +24,30 @@ struct Divide { struct Remainder { template - T operator()(T x, T y) { + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { return x % y; } - template <> - float operator()(float x, float y) { - return fmod(x, y); + template + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; } template <> - half operator()(half x, half y) { - return fmod(x, y); - } - template <> - bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { - return fmod(x, y); + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; } }; diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 3e4dbc8f2..245ced024 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -14,10 +14,29 @@ struct FloorDivide { }; struct Remainder { - template T operator()(T x, T y) { return x % y; } - template <> float operator()(float x, float y) { return fmod(x, y); } - template <> half operator()(half x, half y) { return fmod(x, y); } - template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); } + template + metal::enable_if_t & !metal::is_signed_v, T> operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } }; template diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index ac966a293..9cb27c5a3 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) { constexpr complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real - (b.real * static_cast(a.real / b.real)); auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); + if (real != 0 && (real < 0 != b.real < 0)) { + real += b.real; + } + if (imag != 0 && (imag < 0 != b.imag < 0)) { + imag += b.imag; + } return {real, imag}; } diff --git a/mlx/types/complex.h b/mlx/types/complex.h index 55cbe447a..19ab1b542 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -35,6 +35,16 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) { return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); } +inline complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && (real < 0 != b.real() < 0)) + real += b.real(); + if (imag != 0 && (imag < 0 != b.imag() < 0)) + imag += b.imag(); + return {real, imag}; +} + inline bool operator<=(const complex64_t& a, const complex64_t& b) { return operator>=(b, a); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 3d84f4b02..9ad6d5a53 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -275,6 +275,20 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(z.dtype, dt) self.assertEqual(z.item(), 1) + z = -1 % x + self.assertEqual(z.dtype, dt) + self.assertEqual(z.item(), 1) + + z = -1 % -x + self.assertEqual(z.dtype, dt) + self.assertEqual(z.item(), -1) + + x = mx.arange(10).astype(dt) - 5 + y = x % 5 + z = x % -5 + self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) + self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1]) + def test_comparisons(self): a = mx.array([0.0, 1.0, 5.0]) b = mx.array([-1.0, 2.0, 5.0]) From b96be943dcea5ce5c32d91fdcf262f5b60855639 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 Feb 2024 16:50:45 -0800 Subject: [PATCH 12/42] bug fix (#658) --- mlx/backend/accelerate/softmax.cpp | 7 ++++++- mlx/backend/common/softmax.cpp | 7 ++++++- mlx/backend/metal/softmax.cpp | 7 ++++++- python/tests/test_ops.py | 5 +++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp index fcd8fbe50..9e7ddf632 100644 --- a/mlx/backend/accelerate/softmax.cpp +++ b/mlx/backend/accelerate/softmax.cpp @@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous auto check_input = [](array x) { - if (x.strides()[x.ndim() - 1] == 1) { + bool no_copy = x.strides()[x.ndim() - 1] == 1; + if (x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 1]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp index 90874c72d..564fd1f22 100644 --- a/mlx/backend/common/softmax.cpp +++ b/mlx/backend/common/softmax.cpp @@ -53,7 +53,12 @@ void Softmax::eval(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous auto check_input = [](array x) { - if (x.strides().back() == 1) { + bool no_copy = x.strides()[x.ndim() - 1] == 1; + if (x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 1]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 33ec8014c..7edc91b55 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous std::vector copies; auto check_input = [&copies, &s](const array& x) { - if (x.strides()[x.ndim() - 1] == 1) { + bool no_copy = x.strides()[x.ndim() - 1] == 1; + if (x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 1]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 9ad6d5a53..433890237 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1386,6 +1386,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue((a[:-1] < 1e-9).all()) self.assertEqual(a[-1], 1) + # Sliced inputs + y = mx.random.uniform(shape=(8, 4)) + out = mx.softmax(y[:, 0:2], axis=-1) + self.assertAlmostEqual(out.sum().item(), 8.0) + def test_concatenate(self): a_npy = np.random.randn(32, 32, 32) b_npy = np.random.randn(32, 32, 32) From 7f3f8d8f8d081fc79456565b07a16b7a2f2da520 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 Feb 2024 17:02:13 -0800 Subject: [PATCH 13/42] Fix the softmax fix (#661) --- mlx/backend/accelerate/softmax.cpp | 2 +- mlx/backend/common/softmax.cpp | 2 +- mlx/backend/metal/softmax.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp index 9e7ddf632..8b95e32d4 100644 --- a/mlx/backend/accelerate/softmax.cpp +++ b/mlx/backend/accelerate/softmax.cpp @@ -276,7 +276,7 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { auto check_input = [](array x) { bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 1]; + auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp index 564fd1f22..87ce748c8 100644 --- a/mlx/backend/common/softmax.cpp +++ b/mlx/backend/common/softmax.cpp @@ -55,7 +55,7 @@ void Softmax::eval(const std::vector& inputs, array& out) { auto check_input = [](array x) { bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 1]; + auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 7edc91b55..be25bc032 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -24,7 +24,7 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { auto check_input = [&copies, &s](const array& x) { bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 1]; + auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { From 11d2c8f7a18cf97e2aba0985b16e0e91fe833576 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 9 Feb 2024 18:17:04 -0800 Subject: [PATCH 14/42] Linux build for CI of other packages (#660) --- .circleci/config.yml | 60 ++++++++++++++++++++++++++++++++++++++++++++ CMakeLists.txt | 6 ++--- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 26aa21ce8..c0cbb6e9a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,6 +7,9 @@ parameters: weekly_build: type: boolean default: false + test_release: + type: boolean + default: false jobs: linux_build_and_test: @@ -170,12 +173,61 @@ jobs: - store_artifacts: path: dist/ + build_linux_test_release: + parameters: + python_version: + type: string + default: "3.9" + extra_env: + type: string + default: "DEV_RELEASE=1" + docker: + - image: ubuntu:20.04 + steps: + - checkout + - run: + name: Build wheel + command: | + PYTHON=python<< parameters.python_version >> + apt-get update + apt-get upgrade -y + DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata + apt-get install -y apt-utils + apt-get install -y software-properties-common + add-apt-repository -y ppa:deadsnakes/ppa + apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full + apt-get install -y libblas-dev liblapack-dev liblapacke-dev + apt-get install -y build-essential git + $PYTHON -m venv env + source env/bin/activate + pip install --upgrade pip + pip install --upgrade cmake + pip install --upgrade pybind11[global] + pip install --upgrade setuptools + pip install pybind11-stubgen + pip install numpy + pip install auditwheel + pip install patchelf + pip install build + << parameters.extra_env >> \ + CMAKE_BUILD_PARALLEL_LEVEL="" \ + pip install . -v + python setup.py generate_stubs + << parameters.extra_env >> \ + CMAKE_BUILD_PARALLEL_LEVEL="" \ + python -m build --wheel + auditwheel show dist/* + auditwheel repair dist/* --plat manylinux_2_31_x86_64 + - store_artifacts: + path: wheelhouse/ + workflows: build_and_test: when: and: - not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.weekly_build >> + - not: << pipeline.parameters.test_release >> jobs: - mac_build_and_test - linux_build_and_test @@ -207,3 +259,11 @@ workflows: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] xcode_version: ["14.3.1", "15.2.0"] build_env: ["DEV_RELEASE=1"] + linux_test_release: + when: << pipeline.parameters.test_release >> + jobs: + - build_linux_test_release: + matrix: + parameters: + python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + extra_env: ["PYPI_RELEASE=1"] diff --git a/CMakeLists.txt b/CMakeLists.txt index b008e10e6..e889353ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -123,8 +123,8 @@ else() /usr/include /usr/local/include $ENV{BLAS_HOME}/include) - message(STATUS "Blas lib" ${BLAS_LIBRARIES}) - message(STATUS "Blas include" ${BLAS_INCLUDE_DIRS}) + message(STATUS "Blas lib " ${BLAS_LIBRARIES}) + message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_link_libraries(mlx ${BLAS_LIBRARIES}) find_package(LAPACK REQUIRED) @@ -134,7 +134,7 @@ else() find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include) - message(STATUS "Lapack lib" ${LAPACK_LIBRARIES}) + message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) target_link_libraries(mlx ${LAPACK_LIBRARIES}) From 06072601cef7876a23d079d9f8c28df28f6e28e8 Mon Sep 17 00:00:00 2001 From: Vijay Krish Date: Sat, 10 Feb 2024 08:49:51 -0800 Subject: [PATCH 15/42] Scatter optimization : Eliminate 64b integer divide. (#662) Launch 2D grid to eliminate divide and mod in device code, since 64b integer division is very expensive. Github Issue #506 Co-authored-by: Vijay Krishnamoorthy --- benchmarks/python/gather_bench.py | 13 +----- benchmarks/python/scatter_bench.py | 56 ++++++++++++++++++++++++ benchmarks/python/time_utils.py | 14 +++++- mlx/backend/metal/indexing.cpp | 5 +-- mlx/backend/metal/kernels/indexing.metal | 10 ++--- 5 files changed, 77 insertions(+), 21 deletions(-) create mode 100644 benchmarks/python/scatter_bench.py diff --git a/benchmarks/python/gather_bench.py b/benchmarks/python/gather_bench.py index 8ce6dcb8f..e000841d2 100644 --- a/benchmarks/python/gather_bench.py +++ b/benchmarks/python/gather_bench.py @@ -5,18 +5,7 @@ from time import time import mlx.core as mx import torch - - -def measure_runtime(fn, **kwargs): - # Warmup - for _ in range(5): - fn(**kwargs) - - tic = time() - iters = 10 - for _ in range(iters): - fn(**kwargs) - return (time() - tic) * 1000 / iters +from time_utils import measure_runtime def benchmark_gather_mlx(x_shape, idx_shape): diff --git a/benchmarks/python/scatter_bench.py b/benchmarks/python/scatter_bench.py new file mode 100644 index 000000000..2d63d8bf1 --- /dev/null +++ b/benchmarks/python/scatter_bench.py @@ -0,0 +1,56 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse + +import mlx.core as mx +import torch +from time_utils import measure_runtime + + +def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape): + def scatter(dst, x, idx): + dst[idx] = x + mx.eval(dst) + + idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape) + x = mx.random.normal(x_shape).astype(mx.float32) + dst = mx.random.normal(dst_shape).astype(mx.float32) + + runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx) + print(f"MLX: {runtime:.3f}ms") + + +def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device): + def gather(dst, x, idx, device): + dst[idx] = x + if device == torch.device("mps"): + torch.mps.synchronize() + + idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device) + x = torch.randn(x_shape, dtype=torch.float32).to(device) + dst = torch.randn(dst_shape, dtype=torch.float32).to(device) + + runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device) + print(f"PyTorch: {runtime:.3f}ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Gather benchmarks.") + parser.add_argument("--cpu", action="store_true", help="Use the CPU.") + args = parser.parse_args() + + if args.cpu: + mx.set_default_device(mx.cpu) + device = torch.device("cpu") + else: + device = torch.device("mps") + + dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)] + idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)] + x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)] + + for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes): + print("=" * 20) + print(f"X {x_shape}, Indices {idx_shape}") + benchmark_scatter_mlx(dst_shape, x_shape, idx_shape) + benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device) diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py index 73266cb7a..32d22ed99 100644 --- a/benchmarks/python/time_utils.py +++ b/benchmarks/python/time_utils.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import time @@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs): msec = 1e3 * (toc - tic) / num_iters print(f"{msec:.5f} msec") + + +def measure_runtime(fn, **kwargs): + # Warmup + for _ in range(5): + fn(**kwargs) + + tic = time.time() + iters = 10 + for _ in range(iters): + fn(**kwargs) + return (time.time() - tic) * 1000 / iters diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 28edb9126..cf2256846 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -212,9 +212,6 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { thread_group_size = nthreads; } - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder->setComputePipelineState(kernel); // Make the argument buffer to store the indices for the @@ -317,6 +314,8 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); + MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); compute_encoder->dispatchThreads(grid_dims, group_dims); // Cleanup temporaries diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal index 395bc7819..7b6e2399a 100644 --- a/mlx/backend/metal/kernels/indexing.metal +++ b/mlx/backend/metal/kernels/indexing.metal @@ -187,11 +187,11 @@ template const device size_t *out_strides [[buffer(8)]], const device size_t& out_ndim [[buffer(9)]], const device int* axes [[buffer(10)]], - uint gid [[thread_position_in_grid]]) { + uint2 gid [[thread_position_in_grid]]) { Op op; - auto ind_idx = gid / upd_size; - auto ind_offset = gid % upd_size; + auto ind_idx = gid.y; + auto ind_offset = gid.x; size_t out_idx = 0; for (int i = 0; i < NIDX; ++i) { @@ -208,7 +208,7 @@ template auto out_offset = elem_to_loc( ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); - auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim); + auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); op.atomic_update(out, updates[upd_idx], out_idx + out_offset); } @@ -226,7 +226,7 @@ template [[host_name("scatter" name "_" #nindex)]] \ const device size_t *out_strides [[buffer(8)]], \ const device size_t& out_ndim [[buffer(9)]], \ const device int* axes [[buffer(10)]], \ - uint gid [[thread_position_in_grid]]); + uint2 gid [[thread_position_in_grid]]); // Special case NINDEX=0 #define instantiate_scatter_nd0(name, type) \ From 0dbc4c754714085af01c3a9bfdcbb6db61b0eda1 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Sun, 11 Feb 2024 18:08:20 +0400 Subject: [PATCH 16/42] feat: Update pre-commit-config.yaml (#667) --- .pre-commit-config.yaml | 2 +- benchmarks/python/comparative/compare.py | 6 ++---- python/tests/test_load.py | 16 ++++++++++------ python/tests/test_ops.py | 4 +--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d0ebc6d48..279ab5c91 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: clang-format # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: https://github.com/psf/black-pre-commit-mirror - rev: 23.12.1 + rev: 24.1.1 hooks: - id: black - repo: https://github.com/pycqa/isort diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index a9d3df22d..5b71cf583 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -80,10 +80,8 @@ if __name__ == "__main__": _filter = make_predicate(args.filter, args.negative_filter) if args.mlx_dtypes: - compare_filtered = ( - lambda x: compare_mlx_dtypes( - x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1] - ) + compare_filtered = lambda x: ( + compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]) if _filter(x) else None ) diff --git a/python/tests/test_load.py b/python/tests/test_load.py index ab2645bcf..fdf06041a 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -84,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase): self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } with open(save_file_mlx, "wb") as f: @@ -113,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase): self.test_dir, f"mlx_{dt}_{i}_fs.gguf" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } mx.save_gguf(save_file_mlx, save_dict) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 433890237..edb98032b 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1333,9 +1333,7 @@ class TestOps(mlx_tests.MLXTestCase): for d in dims: anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d) for n_bsx in range(d): - bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape( - [size] * n_bsx - ) + bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx) for _ in range(trial_mul * d): amlx = mx.array(anp) bmlx = mx.array(bnp) From d12573daa60c14a371637c9c33ca8def428c6445 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 11 Feb 2024 10:33:30 -0800 Subject: [PATCH 17/42] quote file name (#670) --- mlx/backend/metal/make_compiled_preamble.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh index 1271f567d..26b575de4 100644 --- a/mlx/backend/metal/make_compiled_preamble.sh +++ b/mlx/backend/metal/make_compiled_preamble.sh @@ -12,7 +12,7 @@ SRCDIR=$3 CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null) -cat << EOF > $OUTPUT_FILE +cat << EOF > "$OUTPUT_FILE" // Copyright © 2023-24 Apple Inc. namespace mlx::core::metal { From 3756381358b0df088f62e46b2942efefcc94274d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 11 Feb 2024 21:53:16 -0800 Subject: [PATCH 18/42] Faster bfloat quantized mat-vec and vec-mat (#663) --- mlx/backend/metal/kernels/quantized.metal | 46 ++++++++++++++--------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 5bf3142d4..0de84093d 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -15,6 +15,14 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; +template struct AccT { + typedef T acc_t; +}; + +template <> struct AccT { + typedef float acc_t; +}; + template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], @@ -37,15 +45,16 @@ template ::acc_t U; + threadgroup U scales_block[BM * groups_per_block]; + threadgroup U biases_block[BM * groups_per_block]; + threadgroup U x_block[colgroup]; thread uint32_t w_local; - thread T result = 0; - thread T scale = 1; - thread T bias = 0; - thread T x_thread[el_per_thread]; + thread U result = 0; + thread U scale = 1; + thread U bias = 0; + thread U x_thread[el_per_thread]; // Adjust positions const int in_vec_size_w = in_vec_size / el_per_thread; @@ -90,7 +99,7 @@ template (w_local & bitmask) + bias) * x_thread[k]; + result += (scale * static_cast(w_local & bitmask) + bias) * x_thread[k]; w_local >>= bits; } } @@ -100,7 +109,7 @@ template (result); } } @@ -129,15 +138,16 @@ template ::acc_t U; + threadgroup U scales_block[BM * groups_per_block]; + threadgroup U biases_block[BM * groups_per_block]; + threadgroup U x_block[BM]; thread uint32_t w_local; - thread T result[el_per_int] = {0}; - thread T scale = 1; - thread T bias = 0; - thread T x_local = 0; + thread U result[el_per_int] = {0}; + thread U scale = 1; + thread U bias = 0; + thread U x_local = 0; // Adjust positions const int out_vec_size_w = out_vec_size / el_per_int; @@ -186,7 +196,7 @@ template (w_local & bitmask) + bias) * x_local; + result[k] += (scale * static_cast(w_local & bitmask) + bias) * x_local; w_local >>= bits; } } @@ -201,7 +211,7 @@ template (result[k]); } } } From 74caa68d0255b28ea88b829a0f54a4442de366b0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 12 Feb 2024 12:25:04 -0800 Subject: [PATCH 19/42] nit in readme (#675) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a9abd58ad..118cc828e 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx) -MLX is an array framework for machine learning on Apple silicon, brought to you -by Apple machine learning research. +MLX is an array framework for machine learning research on Apple silicon, +brought to you by Apple machine learning research. Some key features of MLX include: From 4cc70290f794104225657f27d027d0d1ac20fdcb Mon Sep 17 00:00:00 2001 From: Mike Drob Date: Mon, 12 Feb 2024 19:47:21 -0600 Subject: [PATCH 20/42] PR Builder Workflow (#659) --- .circleci/config.yml | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c0cbb6e9a..537f15969 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -225,6 +225,7 @@ workflows: build_and_test: when: and: + - equal: [ main, << pipeline.git.branch >> ] - not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.weekly_build >> - not: << pipeline.parameters.test_release >> @@ -242,8 +243,23 @@ workflows: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] xcode_version: ["14.3.1", "15.2.0"] build_env: ["PYPI_RELEASE=1"] + prb: + when: + matches: + pattern: "^pull/\\d+(/head)?$" + value: << pipeline.git.branch >> + jobs: + - hold: + type: approval + - mac_build_and_test: + requires: [ hold ] + - linux_build_and_test: + requires: [ hold ] nightly_build: - when: << pipeline.parameters.nightly_build >> + when: + and: + - equal: [ main, << pipeline.git.branch >> ] + - << pipeline.parameters.nightly_build >> jobs: - build_release: matrix: @@ -251,7 +267,10 @@ workflows: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] xcode_version: ["14.3.1", "15.2.0"] weekly_build: - when: << pipeline.parameters.weekly_build >> + when: + and: + - equal: [ main, << pipeline.git.branch >> ] + - << pipeline.parameters.weekly_build >> jobs: - build_release: matrix: @@ -260,7 +279,10 @@ workflows: xcode_version: ["14.3.1", "15.2.0"] build_env: ["DEV_RELEASE=1"] linux_test_release: - when: << pipeline.parameters.test_release >> + when: + and: + - equal: [ main, << pipeline.git.branch >> ] + - << pipeline.parameters.test_release >> jobs: - build_linux_test_release: matrix: From 40c108766b146453fea2d1662382658334d069c6 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 12 Feb 2024 18:54:21 -0800 Subject: [PATCH 21/42] Quantized matmul fix (#677) * Fix qmv for small or unaligned matrices * Fix qmm --- mlx/backend/metal/kernels/quantized.metal | 24 ++++++--- mlx/backend/metal/quantized.cpp | 2 +- python/tests/test_quantized.py | 64 +++++++++++++++++++++++ 3 files changed, 81 insertions(+), 9 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 0de84093d..c2bfba9f9 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -39,11 +39,12 @@ template ::acc_t U; threadgroup U scales_block[BM * groups_per_block]; @@ -66,12 +67,19 @@ template = out_vec_size) { + return; + } + // Loop over in_vec in blocks of colgroup for (int i=0; i; using loader_x_t = mlx::steel::BlockLoader; - threadgroup T scales_block[BN * groups_per_block]; threadgroup T biases_block[BN * groups_per_block]; threadgroup T Xs[BM * BK]; @@ -313,7 +320,7 @@ template = K) { + if (num_k < BK) { for (int wo=0; wo& inputs, array& out) { int bo = std::min(32, O); int bd = 32; MTL::Size group_dims = MTL::Size(bd, bo, 1); - MTL::Size grid_dims = MTL::Size(1, O / bo, B); + MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B); set_array_buffer(compute_encoder, w, 0); set_array_buffer(compute_encoder, scales, 1); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index b068aa6ee..fad2ba51c 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -165,6 +165,70 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_non_multiples(self): + w = mx.random.normal(shape=(33, 256)) + w_q, scales, biases = mx.quantize(w) + w_hat = mx.dequantize(w_q, scales, biases) + + # Test qmv + x = mx.random.normal(shape=(1, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm_t + x = mx.random.normal(shape=(10, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qvm + x = mx.random.normal(shape=(1, 33)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm + x = mx.random.normal(shape=(10, 33)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Smaller than 8 + w = mx.random.normal(shape=(3, 256)) + w_q, scales, biases = mx.quantize(w) + w_hat = mx.dequantize(w_q, scales, biases) + + # Test qmv + x = mx.random.normal(shape=(1, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm_t + x = mx.random.normal(shape=(10, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qvm + x = mx.random.normal(shape=(1, 3)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm + x = mx.random.normal(shape=(10, 3)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + if __name__ == "__main__": unittest.main() From e54cbb7ba64faf58a08c11383a25fed837429be4 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Tue, 13 Feb 2024 07:08:13 +0100 Subject: [PATCH 22/42] Pooling layers (#357) Co-authored-by: Angelos Katharopoulos Co-authored-by: Awni Hannun --- ACKNOWLEDGMENTS.md | 2 +- docs/src/python/nn/layers.rst | 4 + python/mlx/nn/layers/__init__.py | 1 + python/mlx/nn/layers/pooling.py | 308 ++++++++++++++++++++++++++++ python/tests/test_nn.py | 341 +++++++++++++++++++++++++++++++ 5 files changed, 655 insertions(+), 1 deletion(-) create mode 100644 python/mlx/nn/layers/pooling.py diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 18a8c5599..2a3c6c612 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals: - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support -- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. +- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index ef099ef2f..0f5fca9db 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -10,6 +10,8 @@ Layers :template: nn-module-template.rst ALiBi + AvgPool1d + AvgPool2d BatchNorm Conv1d Conv2d @@ -22,6 +24,8 @@ Layers InstanceNorm LayerNorm Linear + MaxPool1d + MaxPool2d Mish MultiHeadAttention PReLU diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index f5092418b..207cb01b2 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -58,6 +58,7 @@ from mlx.nn.layers.normalization import ( LayerNorm, RMSNorm, ) +from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.transformer import ( diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py new file mode 100644 index 000000000..ffa05f5d2 --- /dev/null +++ b/python/mlx/nn/layers/pooling.py @@ -0,0 +1,308 @@ +# Copyright © 2023-2024 Apple Inc. + +import operator +from itertools import accumulate +from typing import Optional, Tuple, Union + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +def _value_or_list(x, n, msg): + if isinstance(x, (list, tuple)): + if len(x) != n: + raise ValueError(msg) + return list(x) + + if not isinstance(x, int): + raise ValueError(msg) + + return [x] * n + + +def _sliding_windows(x, window_shape, window_strides): + if x.ndim < 3: + raise ValueError( + f"To extract sliding windows at least 1 spatial dimension " + f"(3 total) is needed but the input only has {x.ndim} dimensions." + ) + + spatial_dims = x.shape[1:-1] + if not (len(spatial_dims) == len(window_shape) == len(window_strides)): + raise ValueError( + f"To extract sliding windows the window shapes and strides must have " + f"the same number of spatial dimensions as the signal but the signal " + f"has {len(spatial_dims)} dims and the window shape has {len(window_shape)} " + f"and strides have {len(window_strides)}." + ) + + shape = x.shape + strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:] + + # Compute the output shape + final_shape = [shape[0]] + final_shape += [ + (size - window) // stride + 1 + for size, window, stride in zip(spatial_dims, window_shape, window_strides) + ] + final_shape += window_shape + final_shape += [shape[-1]] + + # Compute the output strides + final_strides = strides[:1] + final_strides += [ + og_stride * stride for og_stride, stride in zip(strides[1:-1], window_strides) + ] + final_strides += strides[1:-1] + final_strides += strides[-1:] # should always be [1] + + return mx.as_strided(x, final_shape, final_strides) + + +class _Pool(Module): + def __init__(self, pooling_function, kernel_size, stride, padding, padding_value): + super().__init__() + + self._pooling_function = pooling_function + self._kernel_size = kernel_size + self._stride = stride + self._padding = padding + self._padding_value = padding_value + self._axes = tuple(range(-len(self._kernel_size) - 1, -1, 1)) + + def _extra_repr(self): + ks = tuple(self._kernel_size) + st = tuple(self._stride) + pd = tuple(p[0] for p in self._padding) + + return f"kernel_size={ks}, stride={st}, padding={pd}" + + def __call__(self, x): + if any(p[0] > 0 for p in self._padding): + x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value) + x = _sliding_windows(x, self._kernel_size, self._stride) + return self._pooling_function(x, self._axes) + + +class _Pool1d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int]], + stride: Optional[Union[int, Tuple[int]]] = None, + padding: Union[int, Tuple[int]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 1 integer" + kernel_size = _value_or_list( + kernel_size, 1, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 1, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 1, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class _Pool2d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 2 integers" + kernel_size = _value_or_list( + kernel_size, 2, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 2, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 2, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class MaxPool1d(_Pool1d): + r"""Applies 1-dimensional max pooling. + + Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is + :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given + by: + + .. math:: + \text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1} + \text{input}(N_i, \text{stride} \times t + m, C_j), + + where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - + \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + + Args: + kernel_size (int or tuple(int)): The size of the pooling window kernel. + stride (int or tuple(int), optional): The stride of the pooling window. + Default: ``kernel_size``. + padding (int or tuple(int), optional): How much negative infinity + padding to apply to the input. The padding amount is applied to + both sides of the spatial axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(4, 16, 5)) + >>> pool = nn.MaxPool1d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool1d(_Pool1d): + r"""Applies 1-dimensional average pooling. + + Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is + :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given + by: + + .. math:: + \text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1} + \text{input}(N_i, \text{stride} \times t + m, C_j), + + where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - + \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + + Args: + kernel_size (int or tuple(int)): The size of the pooling window kernel. + stride (int or tuple(int), optional): The stride of the pooling window. + Default: ``kernel_size``. + padding (int or tuple(int), optional): How much zero padding to apply to + the input. The padding amount is applied to both sides of the spatial + axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(4, 16, 5)) + >>> pool = nn.AvgPool1d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) + + +class MaxPool2d(_Pool2d): + r"""Applies 2-dimensional max pooling. + + Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is + :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out}, + W_{out}, C)`, given by: + + .. math:: + \begin{aligned} + \text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n, C_j), + \end{aligned} + + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for both the + height and width axis; + - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is + used for the height axis, the second ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int)): The size of the pooling window. + stride (int or tuple(int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int), optional): How much negative infinity + padding to apply to the input. The padding is applied on both sides + of the height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 32, 32, 4)) + >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool2d(_Pool2d): + r"""Applies 2-dimensional average pooling. + + Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is + :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out}, + W_{out}, C)`, given by: + + .. math:: + \begin{aligned} + \text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n, C_j), + \end{aligned} + + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for both the + height and width axis; + - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is + used for the height axis, the second ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int)): The size of the pooling window. + stride (int or tuple(int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int), optional): How much zero + padding to apply to the input. The padding is applied on both sides + of the height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 32, 32, 4)) + >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 201665f7f..eaaf3bb9c 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -905,6 +905,347 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertTrue(y.shape, x.shape) self.assertTrue(y.dtype, mx.float16) + def test_pooling(self): + # Test 1d pooling + x = mx.array( + [ + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], + [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], + ] + ) + expected_max_pool_output_no_padding_stride_1 = [ + [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_no_padding_stride_2 = [ + [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_padding_1_stride_2 = [ + [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + [[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_padding_1_stride_2_kernel_3 = [ + [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], + ] + expected_avg_pool_output_no_padding_stride_1 = [ + [ + [1.5000, 2.5000, 3.5000], + [4.5000, 5.5000, 6.5000], + [7.5000, 8.5000, 9.5000], + ], + [ + [13.5000, 14.5000, 15.5000], + [16.5000, 17.5000, 18.5000], + [19.5000, 20.5000, 21.5000], + ], + ] + expected_avg_pool_output_no_padding_stride_2 = [ + [[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]], + [[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]], + ] + expected_avg_pool_output_padding_1_stride_2 = [ + [ + [0.0000, 0.5000, 1.0000], + [4.5000, 5.5000, 6.5000], + [4.5000, 5.0000, 5.5000], + ], + [ + [6.0000, 6.5000, 7.0000], + [16.5000, 17.5000, 18.5000], + [10.5000, 11.0000, 11.5000], + ], + ] + expected_avg_pool_output_padding_1_kernel_3 = [ + [[1, 1.66667, 2.33333], [6, 7, 8]], + [[9, 9.66667, 10.3333], [18, 19, 20]], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=1, padding=0)(x), + expected_max_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=2, padding=0)(x), + expected_max_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=2, padding=1)(x), + expected_max_pool_output_padding_1_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=3, stride=2, padding=1)(x), + expected_max_pool_output_padding_1_stride_2_kernel_3, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=1, padding=0)(x), + expected_avg_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=2, padding=0)(x), + expected_avg_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=2, padding=1)(x), + expected_avg_pool_output_padding_1_stride_2, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=3, stride=2, padding=1)(x), + expected_avg_pool_output_padding_1_kernel_3, + ) + ) + # Test 2d pooling + x = mx.array( + [ + [ + [[0, 16], [1, 17], [2, 18], [3, 19]], + [[4, 20], [5, 21], [6, 22], [7, 23]], + [[8, 24], [9, 25], [10, 26], [11, 27]], + [[12, 28], [13, 29], [14, 30], [15, 31]], + ] + ] + ) + expected_max_pool_output_no_padding_stride_1 = [ + [ + [[5, 21], [6, 22], [7, 23]], + [[9, 25], [10, 26], [11, 27]], + [[13, 29], [14, 30], [15, 31]], + ] + ] + expected_max_pool_output_no_padding_stride_2 = [ + [[[5, 21], [7, 23]], [[13, 29], [15, 31]]] + ] + expected_max_pool_output_padding_1 = [ + [ + [[0, 16], [2, 18], [3, 19]], + [[8, 24], [10, 26], [11, 27]], + [[12, 28], [14, 30], [15, 31]], + ] + ] + expected_mean_pool_output_no_padding_stride_1 = [ + [ + [[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]], + [[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]], + [[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]], + ] + ] + expected_mean_pool_output_no_padding_stride_2 = [ + [ + [[2.5000, 18.5000], [4.5000, 20.5000]], + [[10.5000, 26.5000], [12.5000, 28.5000]], + ] + ] + expected_mean_pool_output_padding_1 = [ + [ + [[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]], + [[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]], + [[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]], + ] + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=1, padding=0)(x), + expected_max_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=2, padding=0)(x), + expected_max_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(x), + expected_max_pool_output_padding_1, + ) + ) + # Average pooling + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=2, stride=1, padding=0)(x), + expected_mean_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(x), + expected_mean_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool2d(kernel_size=2, stride=2, padding=1)(x), + expected_mean_pool_output_padding_1, + ) + ) + # Test multiple batches + x = mx.array( + [ + [ + [[0, 1], [2, 3], [4, 5], [6, 7]], + [[8, 9], [10, 11], [12, 13], [14, 15]], + [[16, 17], [18, 19], [20, 21], [22, 23]], + [[24, 25], [26, 27], [28, 29], [30, 31]], + ], + [ + [[32, 33], [34, 35], [36, 37], [38, 39]], + [[40, 41], [42, 43], [44, 45], [46, 47]], + [[48, 49], [50, 51], [52, 53], [54, 55]], + [[56, 57], [58, 59], [60, 61], [62, 63]], + ], + ] + ) + expected_max_pool_output = [ + [[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]], + [[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]], + ] + expected_avg_pool_output = [ + [[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]], + [[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x), + expected_max_pool_output, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1)(x), + expected_avg_pool_output, + ) + ) + # Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2) + x = mx.array( + [ + [ + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], + [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], + [[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]], + [[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]], + ], + [ + [[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]], + [[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]], + [[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]], + [[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]], + ], + ] + ) + expected_irregular_max_pool_output = [ + [ + [ + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [9.0, 10.0, 11.0], + [9.0, 10.0, 11.0], + ], + [ + [39.0, 40.0, 41.0], + [42.0, 43.0, 44.0], + [45.0, 46.0, 47.0], + [45.0, 46.0, 47.0], + [45.0, 46.0, 47.0], + ], + ], + [ + [ + [51.0, 52.0, 53.0], + [54.0, 55.0, 56.0], + [57.0, 58.0, 59.0], + [57.0, 58.0, 59.0], + [57.0, 58.0, 59.0], + ], + [ + [87.0, 88.0, 89.0], + [90.0, 91.0, 92.0], + [93.0, 94.0, 95.0], + [93.0, 94.0, 95.0], + [93.0, 94.0, 95.0], + ], + ], + ] + expected_irregular_average_pool_output = [ + [ + [ + [0.3750, 0.6250, 0.8750], + [1.1250, 1.5000, 1.8750], + [2.2500, 2.7500, 3.2500], + [2.2500, 2.6250, 3.0000], + [1.8750, 2.1250, 2.3750], + ], + [ + [15.7500, 16.2500, 16.7500], + [24.7500, 25.5000, 26.2500], + [34.5000, 35.5000, 36.5000], + [27.0000, 27.7500, 28.5000], + [18.7500, 19.2500, 19.7500], + ], + ], + [ + [ + [12.3750, 12.6250, 12.8750], + [19.1250, 19.5000, 19.8750], + [26.2500, 26.7500, 27.2500], + [20.2500, 20.6250, 21.0000], + [13.8750, 14.1250, 14.3750], + ], + [ + [39.7500, 40.2500, 40.7500], + [60.7500, 61.5000, 62.2500], + [82.5000, 83.5000, 84.5000], + [63.0000, 63.7500, 64.5000], + [42.7500, 43.2500, 43.7500], + ], + ], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), + expected_irregular_max_pool_output, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), + expected_irregular_average_pool_output, + ) + ) + # Test repr + self.assertEqual( + str(nn.MaxPool1d(kernel_size=3, padding=2)), + "MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))", + ) + self.assertEqual( + str(nn.AvgPool1d(kernel_size=2, stride=3)), + "AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))", + ) + self.assertEqual( + str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + "MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))", + ) + self.assertEqual( + str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))), + "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", + ) + if __name__ == "__main__": unittest.main() From be6e9d6a9f0c6e982584e0c2adbd3483cbdb2703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Tue, 13 Feb 2024 11:39:02 -0500 Subject: [PATCH 23/42] Fixed wording in extensions.rst (#678) changed "learn how add" -> "learn how to add" --- docs/src/dev/extensions.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 3563305bf..5f63c6337 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the You would really like the part of your applications that does this operation on the CPU to be very fast - so you decide that you want it to rely on the ``axpby`` routine provided by the Accelerate_ framework. Continuing to impose -our assumptions on to you, let's also assume that you want to learn how add +our assumptions on to you, let's also assume that you want to learn how to add your own implementation for the gradients of your new operation while going over the ins-and-outs of the MLX framework. From 2fdc2462c3b7e261d69700075a50252cb2e8c83f Mon Sep 17 00:00:00 2001 From: Vijay Krish Date: Tue, 13 Feb 2024 17:47:41 -0800 Subject: [PATCH 24/42] Faster gather and scatter. (#682) Reduce unnecessary integer ops, especially since there kernels are integer bound. Increase number of iterations for benchmarks for better smoothing. Github Issue #506 Co-authored-by: Vijay Krishnamoorthy --- benchmarks/python/time_utils.py | 2 +- mlx/backend/metal/kernels/utils.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py index 32d22ed99..f10635ec9 100644 --- a/benchmarks/python/time_utils.py +++ b/benchmarks/python/time_utils.py @@ -28,7 +28,7 @@ def measure_runtime(fn, **kwargs): fn(**kwargs) tic = time.time() - iters = 10 + iters = 100 for _ in range(iters): fn(**kwargs) return (time.time() - tic) * 1000 / iters diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index f9d507cf2..8ef1127b6 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -71,7 +71,7 @@ inline size_t elem_to_loc( device const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; } @@ -84,7 +84,7 @@ inline size_t elem_to_loc( constant const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; } From 0c65517e910e3d0d1e07bc853561b4fd6f92d42c Mon Sep 17 00:00:00 2001 From: Noah Farr <69793313+noahfarr@users.noreply.github.com> Date: Wed, 14 Feb 2024 02:49:31 +0100 Subject: [PATCH 25/42] Return empty array when repeats is 0 in mx.repeat (#681) * Return empty array when repeats is 0 * Add test case for repeats = 0 --- mlx/ops.cpp | 2 +- python/tests/test_ops.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 01ee6d388..96107d515 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { } if (repeats == 0) { - return array({}, arr.dtype()); + return array(std::initializer_list{}, arr.dtype()); } if (repeats == 1) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index edb98032b..5588ebd62 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1699,6 +1699,8 @@ class TestOps(mlx_tests.MLXTestCase): def test_repeat(self): # Setup data for the tests data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) + # Test repeat 0 times + self.assertCmpNumpy([data, 0], mx.repeat, np.repeat) # Test repeat along axis 0 self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0) # Test repeat along axis 1 From 1eb04aa23f92389b45e863f142d86210ae6ec96b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 13 Feb 2024 23:34:17 -0800 Subject: [PATCH 26/42] Fix empty array construction in cpp (#684) --- mlx/array.cpp | 7 +++++++ mlx/array.h | 3 +++ mlx/ops.cpp | 2 +- tests/array_tests.cpp | 18 ++++++++++++++++++ 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index 7f3dd854b..83c2fe6d7 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -82,6 +82,13 @@ array::array(std::initializer_list data) init(data.begin()); } +array::array(std::initializer_list data, Dtype dtype) + : array_desc_(std::make_shared( + std::vector{static_cast(data.size())}, + dtype)) { + init(data.begin()); +} + /* Build an array from a shared buffer */ array::array( allocator::Buffer data, diff --git a/mlx/array.h b/mlx/array.h index 5eefcf727..fe01cbfd7 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -41,6 +41,9 @@ class array { /* Special case so empty lists default to float32. */ array(std::initializer_list data); + /* Special case so array({}, type) is an empty array. */ + array(std::initializer_list data, Dtype dtype); + template array( std::initializer_list data, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 96107d515..01ee6d388 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { } if (repeats == 0) { - return array(std::initializer_list{}, arr.dtype()); + return array({}, arr.dtype()); } if (repeats == 1) { diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp index 080d53daa..62341c5c7 100644 --- a/tests/array_tests.cpp +++ b/tests/array_tests.cpp @@ -591,3 +591,21 @@ TEST_CASE("test array shared buffer") { eval(a + b); } + +TEST_CASE("test make empty array") { + auto a = array({}); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), float32); + + a = array({}, int32); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), int32); + + a = array({}, float32); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), float32); + + a = array({}, bool_); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), bool_); +} From 1a48713d32268ba9ffaa8ff744c55e9fce9356a8 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 14 Feb 2024 13:42:13 -0800 Subject: [PATCH 27/42] Update gather and scatter to not use Argument Encoder (#683) * Replace argument encoder usage for gather and scatter * Use constant address space for shapes and strides * Split gather and scatter to improve compile times * Enable the GPU tests * Update the CI config * Fix scatter dispatch for scalar indices * Remove arg encoder utils --------- Co-authored-by: Angelos Katharopoulos --- .circleci/config.yml | 12 +- mlx/backend/metal/device.cpp | 9 - mlx/backend/metal/indexing.cpp | 211 ++++++----------- mlx/backend/metal/kernels/CMakeLists.txt | 4 +- mlx/backend/metal/kernels/gather.metal | 187 +++++++++++++++ mlx/backend/metal/kernels/indexing.h | 54 +++++ mlx/backend/metal/kernels/indexing.metal | 290 ----------------------- mlx/backend/metal/kernels/scatter.metal | 194 +++++++++++++++ mlx/backend/metal/utils.h | 14 -- 9 files changed, 514 insertions(+), 461 deletions(-) create mode 100644 mlx/backend/metal/kernels/gather.metal create mode 100644 mlx/backend/metal/kernels/indexing.h delete mode 100644 mlx/backend/metal/kernels/indexing.metal create mode 100644 mlx/backend/metal/kernels/scatter.metal diff --git a/.circleci/config.yml b/.circleci/config.yml index 537f15969..5f26778c4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -94,8 +94,7 @@ jobs: command: | source env/bin/activate LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu - # TODO: Reenable when Circle CI can run gpu jobs - # DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu + LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu # TODO: Reenable when extension api becomes stable # - run: # name: Build example extension @@ -110,8 +109,9 @@ jobs: mkdir -p build && cd build && cmake .. && make -j - run: name: Run CPP tests - #command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests - command: DEVICE=cpu ./build/tests/tests + command: | + DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests + DEVICE=cpu ./build/tests/tests build_release: parameters: @@ -225,7 +225,9 @@ workflows: build_and_test: when: and: - - equal: [ main, << pipeline.git.branch >> ] + - matches: + pattern: "^(?!pull/)[-\\w]+$" + value: << pipeline.git.branch >> - not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.weekly_build >> - not: << pipeline.parameters.test_release >> diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index e50441d48..1fb2bd46f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -215,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) { return eit->second; } -MTL::ArgumentEncoder* Device::argument_encoder( - const std::vector& arg_descs) const { - // NB array here is already autoreleased but the returned argument - // encoder is owned by the caller and must be released/autoreleased - NS::Array* arg_desc_arr = NS::Array::array( - reinterpret_cast(arg_descs.data()), arg_descs.size()); - return device_->newArgumentEncoder(arg_desc_arr); -} - void Device::register_library( const std::string& lib_name, const std::string& lib_path) { diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index cf2256846..6908f8684 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -51,6 +51,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); size_t slice_size = 1; for (auto s : slice_sizes_) { @@ -63,91 +64,50 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto group_dims = get_block_dims(dim0, dim1, 1); MTL::Size grid_dims = MTL::Size(dim0, dim1, 1); - compute_encoder->setComputePipelineState(kernel); + // Collect all idx shapes and strides into one place + std::vector idx_shapes; + std::vector idx_strides; - // Make the argument buffer to store the indices for the - // `Indices` struct in kernels/indexing.metal - std::vector arg_descs(4); - arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[0]->setIndex(0); - arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[0]->setArrayLength(nidx); - - // Shapes - arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[1]->setIndex(nidx + 1); - - // Strides - arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[2]->setIndex(nidx + 2); - - // Indices ndim - arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); - arg_descs[3]->setIndex(nidx + 3); - - // Get the argument encoder - auto arg_enc = d.argument_encoder(arg_descs); - - // Allocate and fill buffers for shapes and strides - auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); - auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); for (int i = 0; i < nidx; ++i) { - std::copy( + idx_shapes.insert( + idx_shapes.end(), inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end(), - static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); - std::copy( + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end(), - static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + inputs[i + 1].strides().end()); } - // Allocate the argument buffer - auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); - - // Register data with the encoder - arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); - for (int i = 0; i < nidx; ++i) { - set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); - } - if (idx_ndim > 0) { - arg_enc->setBuffer( - static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); - compute_encoder->useResource( - static_cast(idx_shapes_buf.ptr()), - MTL::ResourceUsageRead); - arg_enc->setBuffer( - static_cast(idx_strides_buf.ptr()), 0, nidx + 2); - compute_encoder->useResource( - static_cast(idx_strides_buf.ptr()), - MTL::ResourceUsageRead); - } - *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; - // Set all the buffers set_array_buffer(compute_encoder, src, 0); - compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 1); - set_array_buffer(compute_encoder, out, 2); + set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3); - compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4); - compute_encoder->setBytes(&ndim, sizeof(size_t), 5); - compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6); - compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7); + // Set source info + compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2); + compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3); + compute_encoder->setBytes(&ndim, sizeof(size_t), 4); + compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5); + compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6); + // Set index info + // + // We don't need to check for empty idx_shapes because gather has a + // idx_ndim == 0 specialization + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 7); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 8); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 9); + + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid compute_encoder->dispatchThreads(grid_dims, group_dims); - - // Cleanup temporaries - arg_enc->release(); - d.get_command_buffer(s.index)->addCompletedHandler( - [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { - allocator::free(arg_buf); - allocator::free(idx_shapes_buf); - allocator::free(idx_strides_buf); - }); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -214,77 +174,33 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); - // Make the argument buffer to store the indices for the - // `Indices` struct in kernels/indexing.metal - std::vector arg_descs(4); - arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[0]->setIndex(0); - arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[0]->setArrayLength(nidx); - - // Shapes - arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[1]->setIndex(nidx + 1); - - // Strides - arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[2]->setIndex(nidx + 2); - - // Indices ndim - arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); - arg_descs[3]->setIndex(nidx + 3); - - // Get the argument encoder - auto arg_enc = d.argument_encoder(arg_descs); - - // Allocate and fill buffers for shapes and strides + // Collect all idx shapes and strides into one place int idx_ndim = nidx ? inputs[1].ndim() : 0; - auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); - auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); + std::vector idx_shapes; + std::vector idx_strides; + for (int i = 0; i < nidx; ++i) { - std::copy( + idx_shapes.insert( + idx_shapes.end(), inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end(), - static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); - std::copy( + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end(), - static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + inputs[i + 1].strides().end()); } - // Allocate the argument buffer - auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); + // Set all the buffers + set_array_buffer(compute_encoder, upd, 1); + set_array_buffer(compute_encoder, out, 2); - // Register data with the encoder - arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); - for (int i = 0; i < nidx; ++i) { - set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); - } - if (idx_ndim > 0) { - arg_enc->setBuffer( - static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); - compute_encoder->useResource( - static_cast(idx_shapes_buf.ptr()), - MTL::ResourceUsageRead); - arg_enc->setBuffer( - static_cast(idx_strides_buf.ptr()), 0, nidx + 2); - compute_encoder->useResource( - static_cast(idx_strides_buf.ptr()), - MTL::ResourceUsageRead); - } - *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; - - compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 0); + // Set update info size_t upd_ndim = upd.ndim(); size_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } - set_array_buffer(compute_encoder, upd, 1); - set_array_buffer(compute_encoder, out, 2); if (upd_ndim == 0) { // Need placeholders so Metal doesn't compalain int shape_ = 0; @@ -299,6 +215,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + // Set output info size_t out_ndim = out.ndim(); if (out_ndim == 0) { // Need placeholders so Metal doesn't compalain @@ -314,18 +231,28 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + // Set index info + if (idx_ndim == 0) { + // Add a 0 in idx_shapes and strides to avoid the missing buffer binding + // error in the metal API. + idx_shapes.push_back(0); + idx_strides.push_back(0); + } + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); + + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); compute_encoder->dispatchThreads(grid_dims, group_dims); - - // Cleanup temporaries - arg_enc->release(); - d.get_command_buffer(s.index)->addCompletedHandler( - [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { - allocator::free(arg_buf); - allocator::free(idx_shapes_buf); - allocator::free(idx_strides_buf); - }); } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 2d271abb4..12e09deaa 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -6,6 +6,7 @@ set( ${CMAKE_CURRENT_SOURCE_DIR}/complex.h ${CMAKE_CURRENT_SOURCE_DIR}/defines.h ${CMAKE_CURRENT_SOURCE_DIR}/erf.h + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h ${CMAKE_CURRENT_SOURCE_DIR}/utils.h ) @@ -26,7 +27,8 @@ set( "softmax" "sort" "unary" - "indexing" + "gather" + "scatter" ) function(build_kernel_base TARGET SRCFILE DEPS) diff --git a/mlx/backend/metal/kernels/gather.metal b/mlx/backend/metal/kernels/gather.metal new file mode 100644 index 000000000..793b2af62 --- /dev/null +++ b/mlx/backend/metal/kernels/gather.metal @@ -0,0 +1,187 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/indexing.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Gather kernel +///////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void gather_impl( + const device T *src [[buffer(0)]], + device T *out [[buffer(1)]], + const constant int *src_shape [[buffer(2)]], + const constant size_t *src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int *slice_sizes [[buffer(5)]], + const constant int *axes [[buffer(6)]], + const thread Indices& indices, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + + auto ind_idx = index.x; + auto ind_offset = index.y; + + size_t src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + size_t idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = ind_idx * indices.strides[indices.ndim * i]; + } else { + idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + } + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += idx_val * src_strides[ax]; + } + + auto src_offset = elem_to_loc( + ind_offset, slice_sizes, src_strides, src_ndim); + + size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; + out[out_idx] = src[src_offset + src_idx]; + +} + +#define make_gather_impl(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void gather( \ + const device T *src [[buffer(0)]], \ + device T *out [[buffer(1)]], \ + const constant int *src_shape [[buffer(2)]], \ + const constant size_t *src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int *slice_sizes [[buffer(5)]], \ + const constant int *axes [[buffer(6)]], \ + const constant int *idx_shapes [[buffer(7)]], \ + const constant size_t *idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(IdxT) \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]) { \ + \ + Indices idxs{ \ + {{IDX_ARR()}}, \ + idx_shapes, \ + idx_strides, \ + idx_ndim}; \ + \ + return gather_impl( \ + src, \ + out, \ + src_shape, \ + src_strides, \ + src_ndim, \ + slice_sizes, \ + axes, \ + idxs, \ + index, \ + grid_dim); \ +} + +#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) + +make_gather(0) +make_gather(1) +make_gather(2) +make_gather(3) +make_gather(4) +make_gather(5) +make_gather(6) +make_gather(7) +make_gather(8) +make_gather(9) +make_gather(10) + +///////////////////////////////////////////////////////////////////// +// Gather instantiations +///////////////////////////////////////////////////////////////////// + +#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ +template [[host_name("gather" name "_" #nidx "" #nd_name)]] \ +[[kernel]] void gather( \ + const device src_t *src [[buffer(0)]], \ + device src_t *out [[buffer(1)]], \ + const constant int *src_shape [[buffer(2)]], \ + const constant size_t *src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int *slice_sizes [[buffer(5)]], \ + const constant int *axes [[buffer(6)]], \ + const constant int *idx_shapes [[buffer(7)]], \ + const constant size_t *idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(idx_t) \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); + +#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \ + instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) + +#define instantiate_gather4(name, src_t, idx_t, nidx) \ + instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \ + instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \ + instantiate_gather5(name, src_t, idx_t, nidx, 2, ) + + +// Special for case NIDX=0 +instantiate_gather4("bool_", bool, bool, 0) +instantiate_gather4("uint8", uint8_t, bool, 0) +instantiate_gather4("uint16", uint16_t, bool, 0) +instantiate_gather4("uint32", uint32_t, bool, 0) +instantiate_gather4("uint64", uint64_t, bool, 0) +instantiate_gather4("int8", int8_t, bool, 0) +instantiate_gather4("int16", int16_t, bool, 0) +instantiate_gather4("int32", int32_t, bool, 0) +instantiate_gather4("int64", int64_t, bool, 0) +instantiate_gather4("float16", half, bool, 0) +instantiate_gather4("float32", float, bool, 0) +instantiate_gather4("bfloat16", bfloat16_t, bool, 0) + +#define instantiate_gather3(name, src_type, ind_type) \ + instantiate_gather4(name, src_type, ind_type, 1) \ + instantiate_gather4(name, src_type, ind_type, 2) \ + instantiate_gather4(name, src_type, ind_type, 3) \ + instantiate_gather4(name, src_type, ind_type, 4) \ + instantiate_gather4(name, src_type, ind_type, 5) \ + instantiate_gather4(name, src_type, ind_type, 6) \ + instantiate_gather4(name, src_type, ind_type, 7) \ + instantiate_gather4(name, src_type, ind_type, 8) \ + instantiate_gather4(name, src_type, ind_type, 9) \ + instantiate_gather4(name, src_type, ind_type, 10) + +#define instantiate_gather(name, src_type) \ + instantiate_gather3(#name "bool_", src_type, bool) \ + instantiate_gather3(#name "uint8", src_type, uint8_t) \ + instantiate_gather3(#name "uint16", src_type, uint16_t) \ + instantiate_gather3(#name "uint32", src_type, uint32_t) \ + instantiate_gather3(#name "uint64", src_type, uint64_t) \ + instantiate_gather3(#name "int8", src_type, int8_t) \ + instantiate_gather3(#name "int16", src_type, int16_t) \ + instantiate_gather3(#name "int32", src_type, int32_t) \ + instantiate_gather3(#name "int64", src_type, int64_t) + +instantiate_gather(bool_, bool) +instantiate_gather(uint8, uint8_t) +instantiate_gather(uint16, uint16_t) +instantiate_gather(uint32, uint32_t) +instantiate_gather(uint64, uint64_t) +instantiate_gather(int8, int8_t) +instantiate_gather(int16, int16_t) +instantiate_gather(int32, int32_t) +instantiate_gather(int64, int64_t) +instantiate_gather(float16, half) +instantiate_gather(float32, float) +instantiate_gather(bfloat16, bfloat16_t) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h new file mode 100644 index 000000000..c2b37f3ff --- /dev/null +++ b/mlx/backend/metal/kernels/indexing.h @@ -0,0 +1,54 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Indexing utils +///////////////////////////////////////////////////////////////////// + +template +struct Indices { + const array buffers; + const constant int* shapes; + const constant size_t* strides; + const int ndim; +}; + +template +METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { + if (is_unsigned_v) { + return idx; + } else { + return (idx < 0) ? idx + size : idx; + } +} + +#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]], + +#define IDX_ARG_0(idx_t) +#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21) +#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22) +#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23) +#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24) +#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25) +#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26) +#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27) +#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28) +#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29) +#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30) + +#define IDX_ARR_N(n) idx##n, + +#define IDX_ARR_0() +#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21) +#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22) +#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23) +#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24) +#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25) +#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26) +#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27) +#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28) +#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29) +#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal deleted file mode 100644 index 7b6e2399a..000000000 --- a/mlx/backend/metal/kernels/indexing.metal +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/reduce.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; - -///////////////////////////////////////////////////////////////////// -// Gather kernel -///////////////////////////////////////////////////////////////////// - -template -struct Indices { - const array buffers [[id(0)]]; - device int* shapes [[id(NIDX + 1)]]; - device size_t* strides [[id(NIDX + 2)]]; - const int ndim [[id(NIDX + 3)]]; -}; - -template -inline size_t offset_neg_idx(IdxT idx, size_t size) { - return (idx < 0) ? idx + size : idx; -} - -template <> -inline size_t offset_neg_idx(bool idx, size_t) { - return idx; -} - -template <> -inline size_t offset_neg_idx(uint32_t idx, size_t) { - return idx; -} - -// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time -// special case for 0 and 1. Anything >= 2 uses the general case -template -[[kernel]] void gather( - const device T *src [[buffer(0)]], - const constant Indices& indices [[buffer(1)]], - device T *out [[buffer(2)]], - const constant int *src_shape [[buffer(3)]], - const constant size_t *src_strides [[buffer(4)]], - const constant size_t& src_ndim [[buffer(5)]], - const constant int *slice_sizes [[buffer(6)]], - const constant int *axes [[buffer(7)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - - auto ind_idx = index.x; - auto ind_offset = index.y; - - size_t src_idx = 0; - for (int i = 0; i < NIDX; ++i) { - size_t idx_loc; - if (IDX_NDIM == 0) { - idx_loc = 0; - } else if (IDX_NDIM == 1) { - idx_loc = ind_idx * indices.strides[indices.ndim * i]; - } else { - idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - } - auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += idx_val * src_strides[ax]; - } - - auto src_offset = elem_to_loc( - ind_offset, slice_sizes, src_strides, src_ndim); - - size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; - out[out_idx] = src[src_offset + src_idx]; -} - -#define instantiate_gather4(name, src_type, ind_type, nindex) \ -template [[host_name("gather" name "_" #nindex "_0")]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ -template [[host_name("gather" name "_" #nindex "_1")]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ -template [[host_name("gather" name "_" #nindex)]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); - - -// Special for case NIDX=0 -instantiate_gather4("bool_", bool, bool, 0) -instantiate_gather4("uint8", uint8_t, bool, 0) -instantiate_gather4("uint16", uint16_t, bool, 0) -instantiate_gather4("uint32", uint32_t, bool, 0) -instantiate_gather4("uint64", uint64_t, bool, 0) -instantiate_gather4("int8", int8_t, bool, 0) -instantiate_gather4("int16", int16_t, bool, 0) -instantiate_gather4("int32", int32_t, bool, 0) -instantiate_gather4("int64", int64_t, bool, 0) -instantiate_gather4("float16", half, bool, 0) -instantiate_gather4("float32", float, bool, 0) -instantiate_gather4("bfloat16", bfloat16_t, bool, 0) - -#define instantiate_gather3(name, src_type, ind_type) \ - instantiate_gather4(name, src_type, ind_type, 1) \ - instantiate_gather4(name, src_type, ind_type, 2) \ - instantiate_gather4(name, src_type, ind_type, 3) \ - instantiate_gather4(name, src_type, ind_type, 4) \ - instantiate_gather4(name, src_type, ind_type, 5) \ - instantiate_gather4(name, src_type, ind_type, 6) \ - instantiate_gather4(name, src_type, ind_type, 7) \ - instantiate_gather4(name, src_type, ind_type, 8) \ - instantiate_gather4(name, src_type, ind_type, 9) \ - instantiate_gather4(name, src_type, ind_type, 10) - -#define instantiate_gather(name, src_type) \ - instantiate_gather3(#name "bool_", src_type, bool) \ - instantiate_gather3(#name "uint8", src_type, uint8_t) \ - instantiate_gather3(#name "uint16", src_type, uint16_t) \ - instantiate_gather3(#name "uint32", src_type, uint32_t) \ - instantiate_gather3(#name "uint64", src_type, uint64_t) \ - instantiate_gather3(#name "int8", src_type, int8_t) \ - instantiate_gather3(#name "int16", src_type, int16_t) \ - instantiate_gather3(#name "int32", src_type, int32_t) \ - instantiate_gather3(#name "int64", src_type, int64_t) - -instantiate_gather(bool_, bool) -instantiate_gather(uint8, uint8_t) -instantiate_gather(uint16, uint16_t) -instantiate_gather(uint32, uint32_t) -instantiate_gather(uint64, uint64_t) -instantiate_gather(int8, int8_t) -instantiate_gather(int16, int16_t) -instantiate_gather(int32, int32_t) -instantiate_gather(int64, int64_t) -instantiate_gather(float16, half) -instantiate_gather(float32, float) -instantiate_gather(bfloat16, bfloat16_t) - -///////////////////////////////////////////////////////////////////// -// Scatter kernel -///////////////////////////////////////////////////////////////////// - -template -[[kernel]] void scatter( - const device Indices& indices [[buffer(0)]], - const device T *updates [[buffer(1)]], - device mlx_atomic *out [[buffer(2)]], - const device int *upd_shape [[buffer(3)]], - const device size_t *upd_strides [[buffer(4)]], - const device size_t& upd_ndim [[buffer(5)]], - const device size_t& upd_size [[buffer(6)]], - const device int *out_shape [[buffer(7)]], - const device size_t *out_strides [[buffer(8)]], - const device size_t& out_ndim [[buffer(9)]], - const device int* axes [[buffer(10)]], - uint2 gid [[thread_position_in_grid]]) { - - Op op; - auto ind_idx = gid.y; - auto ind_offset = gid.x; - - size_t out_idx = 0; - for (int i = 0; i < NIDX; ++i) { - auto idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += idx_val * out_strides[ax]; - } - - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); - auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx + out_offset); -} - -#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \ -template [[host_name("scatter" name "_" #nindex)]] \ -[[kernel]] void scatter( \ - const device Indices& indices [[buffer(0)]], \ - const device type *updates [[buffer(1)]], \ - device mlx_atomic *out [[buffer(2)]], \ - const device int *upd_shape [[buffer(3)]], \ - const device size_t *upd_strides [[buffer(4)]], \ - const device size_t& upd_ndim [[buffer(5)]], \ - const device size_t& upd_size [[buffer(6)]], \ - const device int *out_shape [[buffer(7)]], \ - const device size_t *out_strides [[buffer(8)]], \ - const device size_t& out_ndim [[buffer(9)]], \ - const device int* axes [[buffer(10)]], \ - uint2 gid [[thread_position_in_grid]]); - -// Special case NINDEX=0 -#define instantiate_scatter_nd0(name, type) \ - instantiate_scatter4(#name "none", type, bool, None, 0) \ - instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ - instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ - instantiate_scatter4(#name "_max", type, bool, Max, 0) \ - instantiate_scatter4(#name "_min", type, bool, Min, 0) - -#define instantiate_scatter3(name, type, ind_type, op_type) \ - instantiate_scatter4(name, type, ind_type, op_type, 1) \ - instantiate_scatter4(name, type, ind_type, op_type, 2) \ - instantiate_scatter4(name, type, ind_type, op_type, 3) \ - instantiate_scatter4(name, type, ind_type, op_type, 4) \ - instantiate_scatter4(name, type, ind_type, op_type, 5) \ - instantiate_scatter4(name, type, ind_type, op_type, 6) \ - instantiate_scatter4(name, type, ind_type, op_type, 7) \ - instantiate_scatter4(name, type, ind_type, op_type, 8) \ - instantiate_scatter4(name, type, ind_type, op_type, 9) \ - instantiate_scatter4(name, type, ind_type, op_type, 10) - -#define instantiate_scatter2(name, type, ind_type) \ - instantiate_scatter3(name "_none", type, ind_type, None) \ - instantiate_scatter3(name "_sum", type, ind_type, Sum) \ - instantiate_scatter3(name "_prod", type, ind_type, Prod) \ - instantiate_scatter3(name "_max", type, ind_type, Max) \ - instantiate_scatter3(name "_min", type, ind_type, Min) - -#define instantiate_scatter(name, type) \ - instantiate_scatter2(#name "bool_", type, bool) \ - instantiate_scatter2(#name "uint8", type, uint8_t) \ - instantiate_scatter2(#name "uint16", type, uint16_t) \ - instantiate_scatter2(#name "uint32", type, uint32_t) \ - instantiate_scatter2(#name "uint64", type, uint64_t) \ - instantiate_scatter2(#name "int8", type, int8_t) \ - instantiate_scatter2(#name "int16", type, int16_t) \ - instantiate_scatter2(#name "int32", type, int32_t) \ - instantiate_scatter2(#name "int64", type, int64_t) - -// TODO uint64 and int64 unsupported -instantiate_scatter_nd0(bool_, bool) -instantiate_scatter_nd0(uint8, uint8_t) -instantiate_scatter_nd0(uint16, uint16_t) -instantiate_scatter_nd0(uint32, uint32_t) -instantiate_scatter_nd0(int8, int8_t) -instantiate_scatter_nd0(int16, int16_t) -instantiate_scatter_nd0(int32, int32_t) -instantiate_scatter_nd0(float16, half) -instantiate_scatter_nd0(float32, float) -instantiate_scatter_nd0(bfloat16, bfloat16_t) - -instantiate_scatter(bool_, bool) -instantiate_scatter(uint8, uint8_t) -instantiate_scatter(uint16, uint16_t) -instantiate_scatter(uint32, uint32_t) -instantiate_scatter(int8, int8_t) -instantiate_scatter(int16, int16_t) -instantiate_scatter(int32, int32_t) -instantiate_scatter(float16, half) -instantiate_scatter(float32, float) -instantiate_scatter(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal new file mode 100644 index 000000000..7a94be7da --- /dev/null +++ b/mlx/backend/metal/kernels/scatter.metal @@ -0,0 +1,194 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/indexing.h" +#include "mlx/backend/metal/kernels/reduce.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Scatter kernel +///////////////////////////////////////////////////////////////////// + + +template +METAL_FUNC void scatter_impl( + const device T *updates [[buffer(1)]], + device mlx_atomic *out [[buffer(2)]], + const constant int *upd_shape [[buffer(3)]], + const constant size_t *upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int *out_shape [[buffer(7)]], + const constant size_t *out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* axes [[buffer(10)]], + const thread Indices& indices, + uint2 gid [[thread_position_in_grid]]) { + + Op op; + auto ind_idx = gid.y; + auto ind_offset = gid.x; + + size_t out_idx = 0; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += idx_val * out_strides[ax]; + } + + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); + op.atomic_update(out, updates[upd_idx], out_idx + out_offset); +} + +#define make_scatter_impl(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void scatter( \ + const device T *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int *upd_shape [[buffer(3)]], \ + const constant size_t *upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int *out_shape [[buffer(7)]], \ + const constant size_t *out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int *idx_shapes [[buffer(11)]], \ + const constant size_t *idx_strides [[buffer(12)]], \ + const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(IdxT) \ + uint2 gid [[thread_position_in_grid]]) { \ + \ + Indices idxs{ \ + {{IDX_ARR()}}, \ + idx_shapes, \ + idx_strides, \ + idx_ndim}; \ + \ + return scatter_impl( \ + updates, \ + out, \ + upd_shape, \ + upd_strides, \ + upd_ndim, \ + upd_size, \ + out_shape, \ + out_strides, \ + out_ndim, \ + axes, \ + idxs, \ + gid); \ +} + +#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) + +make_scatter(0) +make_scatter(1) +make_scatter(2) +make_scatter(3) +make_scatter(4) +make_scatter(5) +make_scatter(6) +make_scatter(7) +make_scatter(8) +make_scatter(9) +make_scatter(10) + +///////////////////////////////////////////////////////////////////// +// Scatter instantiations +///////////////////////////////////////////////////////////////////// + +#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ +template [[host_name("scatter" name "_" #nidx)]] \ +[[kernel]] void scatter( \ + const device src_t *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int *upd_shape [[buffer(3)]], \ + const constant size_t *upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int *out_shape [[buffer(7)]], \ + const constant size_t *out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int *idx_shapes [[buffer(11)]], \ + const constant size_t *idx_strides [[buffer(12)]], \ + const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(idx_t) \ + uint2 gid [[thread_position_in_grid]]); + +#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ + instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) + +// Special case NINDEX=0 +#define instantiate_scatter_nd0(name, type) \ + instantiate_scatter4(#name "none", type, bool, None, 0) \ + instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ + instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ + instantiate_scatter4(#name "_max", type, bool, Max, 0) \ + instantiate_scatter4(#name "_min", type, bool, Min, 0) + +#define instantiate_scatter3(name, type, ind_type, op_type) \ + instantiate_scatter4(name, type, ind_type, op_type, 1) \ + instantiate_scatter4(name, type, ind_type, op_type, 2) \ + instantiate_scatter4(name, type, ind_type, op_type, 3) \ + instantiate_scatter4(name, type, ind_type, op_type, 4) \ + instantiate_scatter4(name, type, ind_type, op_type, 5) \ + instantiate_scatter4(name, type, ind_type, op_type, 6) \ + instantiate_scatter4(name, type, ind_type, op_type, 7) \ + instantiate_scatter4(name, type, ind_type, op_type, 8) \ + instantiate_scatter4(name, type, ind_type, op_type, 9) \ + instantiate_scatter4(name, type, ind_type, op_type, 10) + +#define instantiate_scatter2(name, type, ind_type) \ + instantiate_scatter3(name "_none", type, ind_type, None) \ + instantiate_scatter3(name "_sum", type, ind_type, Sum) \ + instantiate_scatter3(name "_prod", type, ind_type, Prod) \ + instantiate_scatter3(name "_max", type, ind_type, Max) \ + instantiate_scatter3(name "_min", type, ind_type, Min) + +#define instantiate_scatter(name, type) \ + instantiate_scatter2(#name "bool_", type, bool) \ + instantiate_scatter2(#name "uint8", type, uint8_t) \ + instantiate_scatter2(#name "uint16", type, uint16_t) \ + instantiate_scatter2(#name "uint32", type, uint32_t) \ + instantiate_scatter2(#name "uint64", type, uint64_t) \ + instantiate_scatter2(#name "int8", type, int8_t) \ + instantiate_scatter2(#name "int16", type, int16_t) \ + instantiate_scatter2(#name "int32", type, int32_t) \ + instantiate_scatter2(#name "int64", type, int64_t) + +// TODO uint64 and int64 unsupported +instantiate_scatter_nd0(bool_, bool) +instantiate_scatter_nd0(uint8, uint8_t) +instantiate_scatter_nd0(uint16, uint16_t) +instantiate_scatter_nd0(uint32, uint32_t) +instantiate_scatter_nd0(int8, int8_t) +instantiate_scatter_nd0(int16, int16_t) +instantiate_scatter_nd0(int32, int32_t) +instantiate_scatter_nd0(float16, half) +instantiate_scatter_nd0(float32, float) +instantiate_scatter_nd0(bfloat16, bfloat16_t) + +instantiate_scatter(bool_, bool) +instantiate_scatter(uint8, uint8_t) +instantiate_scatter(uint16, uint16_t) +instantiate_scatter(uint32, uint32_t) +instantiate_scatter(int8, int8_t) +instantiate_scatter(int16, int16_t) +instantiate_scatter(int32, int32_t) +instantiate_scatter(float16, half) +instantiate_scatter(float32, float) +instantiate_scatter(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index f7c672c9f..363632a30 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -9,20 +9,6 @@ namespace mlx::core { namespace { -void set_array_buffer( - MTL::ComputeCommandEncoder* compute_encoder, - MTL::ArgumentEncoder* enc, - const array& a, - int idx) { - auto a_buf = static_cast(a.buffer().ptr()); - auto offset = a.data() - - static_cast(const_cast(a_buf)->contents()); - enc->setBuffer(a_buf, offset, idx); - // MTL::Resource usage through argument buffer needs to be explicitly - // flagged to enable hazard tracking - compute_encoder->useResource(a_buf, MTL::ResourceUsageRead); -} - void set_array_buffer( MTL::ComputeCommandEncoder* enc, const array& a, From ccf16459955fd8bb382df94a7f0b2cd379c5431b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 14 Feb 2024 14:04:25 -0800 Subject: [PATCH 28/42] Custom primitive + RoPE fat op (#676) * extensions start * rope custom op * fix build * docs + rope benchmark * fix test * Add a Metal kernel for RoPE * Fix position of traditional * transform tests * Move rope computation to float and fix tests * Fix the test and a typo * change to fast * fix no metal build --------- Co-authored-by: Angelos Katharopoulos --- benchmarks/python/rope_bench.py | 35 +++++ mlx/CMakeLists.txt | 3 +- mlx/backend/common/CMakeLists.txt | 1 + mlx/backend/common/rope.cpp | 14 ++ mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/kernels/CMakeLists.txt | 1 + mlx/backend/metal/kernels/rope.metal | 68 +++++++++ mlx/backend/metal/rope.cpp | 55 +++++++ mlx/backend/no_metal/primitives.cpp | 5 + mlx/fast.cpp | 128 ++++++++++++++++ mlx/fast.h | 82 ++++++++++ mlx/mlx.h | 1 + python/mlx/nn/layers/positional_encoding.py | 78 ++-------- python/src/CMakeLists.txt | 1 + python/src/fast.cpp | 59 ++++++++ python/src/mlx.cpp | 2 + python/src/random.cpp | 2 +- python/tests/test_fast.py | 158 ++++++++++++++++++++ 18 files changed, 624 insertions(+), 70 deletions(-) create mode 100644 benchmarks/python/rope_bench.py create mode 100644 mlx/backend/common/rope.cpp create mode 100644 mlx/backend/metal/kernels/rope.metal create mode 100644 mlx/backend/metal/rope.cpp create mode 100644 mlx/fast.cpp create mode 100644 mlx/fast.h create mode 100644 python/src/fast.cpp create mode 100644 python/tests/test_fast.py diff --git a/benchmarks/python/rope_bench.py b/benchmarks/python/rope_bench.py new file mode 100644 index 000000000..62f01648e --- /dev/null +++ b/benchmarks/python/rope_bench.py @@ -0,0 +1,35 @@ +# Copyright © 2023-2024 Apple Inc. + +import mlx.core as mx +import mlx.nn as nn +from time_utils import time_fn + + +def time_rope(): + rope = nn.RoPE(4096) + + # vec + x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16) + mx.eval(x) + + def rope_vec(x): + for _ in range(32): + x = rope(x) + return x + + time_fn(rope_vec, x) + + # matrix + x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16) + mx.eval(x) + + def rope_mat(x): + for _ in range(32): + x = rope(x) + return x + + time_fn(rope_mat, x) + + +if __name__ == "__main__": + time_rope() diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index cb65e518b..d2f021af5 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -3,9 +3,10 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 0263dff9b..b25001f2c 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -11,6 +11,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp diff --git a/mlx/backend/common/rope.cpp b/mlx/backend/common/rope.cpp new file mode 100644 index 000000000..c0c2bba8e --- /dev/null +++ b/mlx/backend/common/rope.cpp @@ -0,0 +1,14 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/fast.h" +#include "mlx/primitives.h" + +namespace mlx::core::fast { + +void RoPE::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("NYI"); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 93a25434f..063c283fe 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -32,6 +32,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 12e09deaa..afd2fbc8a 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -23,6 +23,7 @@ set( "quantized" "random" "reduce" + "rope" "scan" "softmax" "sort" diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal new file mode 100644 index 000000000..484697b6d --- /dev/null +++ b/mlx/backend/metal/kernels/rope.metal @@ -0,0 +1,68 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +template +[[kernel]] void rope( + const device T *in [[buffer(0)]], + device T * out [[buffer(1)]], + constant const size_t strides[3], + constant const int& offset, + constant const float& base, + constant const float& scale, + uint3 pos [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute the input and output indices + uint in_index_1, in_index_2; + uint out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z)); + out_index_2 = out_index_1 + 1; + in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z)); + out_index_2 = out_index_1 + grid.x; + in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; + in_index_2 = in_index_1 + grid.x * strides[2]; + } + + // Figure out L and d. + float L = scale * static_cast(pos.y + offset); + float d = static_cast(pos.x) / static_cast(grid.x); + + // Compute costheta, sintheta + float theta = L * metal::exp2(-d * base); + float costheta = metal::fast::cos(theta); + float sintheta = metal::fast::sin(theta); + + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1 = x1 * costheta - x2 * sintheta; + float rx2 = x1 * sintheta + x2 * costheta; + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); +} + +#define instantiate_rope(name, type, traditional) \ + template [[host_name("rope_" #name)]] \ + [[kernel]] void rope( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const size_t strides[3], \ + constant const int& offset, \ + constant const float& base, \ + constant const float& scale, \ + uint3 pos [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); + +instantiate_rope(traditional_float16, half, true) +instantiate_rope(traditional_bfloat16, bfloat16_t, true) +instantiate_rope(traditional_float32, float, true) +instantiate_rope(float16, half, false) +instantiate_rope(bfloat16, bfloat16_t, false) +instantiate_rope(float32, float, false) diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp new file mode 100644 index 000000000..29295f3ac --- /dev/null +++ b/mlx/backend/metal/rope.cpp @@ -0,0 +1,55 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/utils.h" +#include "mlx/fast.h" +#include "mlx/primitives.h" + +namespace mlx::core::fast { + +void RoPE::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + auto& in = inputs[0]; + auto& out = outputs[0]; + + if (in.ndim() != 3) { + throw std::runtime_error( + "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)"); + } + if (dims_ != in.shape(-1)) { + throw std::runtime_error("[RoPE] Partial RoPE application not supported"); + } + if (in.flags().row_contiguous && in.is_donatable()) { + out.move_shared_buffer(in); + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + std::ostringstream kname; + kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in); + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + + bool donated = in.data_shared_ptr() == nullptr; + float base = std::log2(base_); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, donated ? out : in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2); + compute_encoder->setBytes(&offset_, sizeof(int), 3); + compute_encoder->setBytes(&base, sizeof(float), 4); + compute_encoder->setBytes(&scale_, sizeof(float), 5); + + int dim0 = in.shape(2) / 2; + int dim1 = in.shape(1); + int dim2 = in.shape(0); + auto group_dims = get_block_dims(dim0, dim1, dim2); + auto grid_dims = MTL::Size(dim0, dim1, dim2); + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index dd4edc2ed..bd4026e2c 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/primitives.h" +#include "mlx/fast.h" #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ @@ -95,4 +96,8 @@ NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Transpose) +namespace fast { +NO_GPU_MULTI(RoPE) +} // namespace fast + } // namespace mlx::core diff --git a/mlx/fast.cpp b/mlx/fast.cpp new file mode 100644 index 000000000..96d4f03ce --- /dev/null +++ b/mlx/fast.cpp @@ -0,0 +1,128 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/fast.h" +#include "mlx/transforms.h" + +namespace mlx::core::fast { + +std::vector Custom::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents); + std::vector vjp_outs; + for (int i = 0, j = 0; i < vjps.size(); ++i) { + if (i < argnums.size() && i == argnums[j]) { + vjp_outs.push_back(vjps[i]); + j++; + } + } + return vjp_outs; +} + +std::vector Custom::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents); + std::vector jvp_outs; + for (int i = 0, j = 0; i < jvps.size(); ++i) { + if (i < argnums.size() && i == argnums[j]) { + jvp_outs.push_back(jvps[i]); + j++; + } + } + return jvp_outs; +} + +std::pair, std::vector> Custom::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto outputs = mlx::core::vmap(fallback_, axes)(inputs); + auto out_axes = std::vector(outputs.size(), 0); + return {outputs, out_axes}; +} + +array rope( + const array& x, + int dims, + bool traditional, + float base, + float scale, + int offset, + StreamOrDevice s /* = {} */) { + if (x.ndim() != 3) { + std::ostringstream msg; + msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if (traditional && x.shape(-1) != dims) { + throw std::invalid_argument( + "[rope] Does not support partial traditional application."); + } + + auto fallback = [dims, traditional, base, scale, offset, s]( + const std::vector& inputs) { + auto& x = inputs[0]; + auto t = x.dtype(); + auto N = x.shape(1) + offset; + // Compute sines and cosines + auto half_dims = dims / 2; + auto positions = multiply(arange(offset, N, t, s), array(scale, t), s); + auto freqs = negative(arange(0, half_dims, t, s), s); + freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s); + auto theta = + multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s); + auto coss = cos(theta, s); + auto sins = sin(theta, s); + + if (traditional) { + auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s); + auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s); + std::vector outs; + outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s)); + outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s)); + for (auto& o : outs) { + o = expand_dims(o, 3, s); + } + return std::vector{reshape(concatenate(outs, 3, s), x.shape(), s)}; + } else { + auto out_s = x.shape(); + out_s.back() = half_dims; + auto x1 = slice(x, {0, 0, 0}, out_s, s); + out_s.back() = dims; + auto x2 = slice(x, {0, 0, half_dims}, out_s, s); + + std::vector outs; + outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s)); + outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s)); + if (dims < x.shape(-1)) { + outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); + } + return std::vector{concatenate(outs, 2, s)}; + } + }; + // TODO change to condition for using custom prim + auto stream = to_stream(s); + if (stream.device == Device::gpu && x.shape(-1) == dims) { + return array( + x.shape(), + x.dtype(), + std::make_unique( + stream, fallback, dims, traditional, base, scale, offset), + {x}); + } + return fallback({x})[0]; +} + +bool RoPE::is_equivalent(const Primitive& other) const { + const RoPE& a_other = static_cast(other); + return ( + dims_ == a_other.dims_ && base_ == a_other.base_ && + scale_ == a_other.scale_ && traditional_ == a_other.traditional_ && + offset_ == a_other.offset_); +} + +} // namespace mlx::core::fast diff --git a/mlx/fast.h b/mlx/fast.h new file mode 100644 index 000000000..5deac0cdb --- /dev/null +++ b/mlx/fast.h @@ -0,0 +1,82 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/ops.h" +#include "mlx/primitives.h" + +namespace mlx::core::fast { + +// Custom primitive accepts a fallback function which it uses for +// transformations. Transformations are virtual so that derived classes may to +// override the default behavior +class Custom : public Primitive { + public: + explicit Custom( + Stream stream, + std::function(std::vector)> fallback) + : Primitive(stream), fallback_(fallback){}; + + virtual std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + + virtual std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + + virtual std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + private: + std::function(std::vector)> fallback_; +}; + +array rope( + const array& x, + int dims, + bool traditional, + float base, + float scale, + int offset, + StreamOrDevice s /* = {} */); + +class RoPE : public Custom { + public: + RoPE( + Stream stream, + std::function(std::vector)> fallback, + int dims, + bool traditional, + float base, + float scale, + int offset) + : Custom(stream, fallback), + dims_(dims), + traditional_(traditional), + base_(base), + scale_(scale), + offset_(offset){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_PRINT(RoPE) + bool is_equivalent(const Primitive& other) const override; + + private: + std::function(std::vector)> fallback_; + int dims_; + bool traditional_; + float base_; + float scale_; + int offset_; +}; + +} // namespace mlx::core::fast diff --git a/mlx/mlx.h b/mlx/mlx.h index 7b33faba7..1963a4c50 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -6,6 +6,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/compile.h" #include "mlx/device.h" +#include "mlx/fast.h" #include "mlx/fft.h" #include "mlx/io.h" #include "mlx/linalg.h" diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index a8024f0a4..f0bb92863 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import math from typing import Optional @@ -20,20 +20,13 @@ class RoPE(Module): Args: dims (int): The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged. - traditional (bool, optional): If set to True choose the traditional + traditional (bool, optional): If set to ``True`` choose the traditional implementation which is slightly less efficient. Default: ``False``. base (float, optional): The base used to compute angular frequency for each dimension in the positional encodings. Default: ``10000``. scale (float, optional): The scale used to scale the positions. Default: ``1.0``. - - Attributes: - _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values. - _cos_sin_theta_value (tuple): Cached cosine and sine values. """ - _cos_sin_theta_key = None - _cos_sin_theta_value = None - def __init__( self, dims: int, @@ -50,69 +43,18 @@ class RoPE(Module): def _extra_repr(self): return f"{self.dims}, traditional={self.traditional}" - def _compute_rope(self, costheta, sintheta, x): - x1 = x[..., : self.dims // 2] - x2 = x[..., self.dims // 2 : self.dims] - rx1 = x1 * costheta - x2 * sintheta - rx2 = x1 * sintheta + x2 * costheta - - if self.dims < x.shape[-1]: - rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1) - else: - rx = mx.concatenate([rx1, rx2], axis=-1) - - return rx - - def _compute_traditional_rope(self, costheta, sintheta, x): - x1 = x[..., ::2] - x2 = x[..., 1::2] - rx1 = x1 * costheta - x2 * sintheta - rx2 = x1 * sintheta + x2 * costheta - - if self.dims < x.shape[-1]: - raise NotImplementedError( - "RoPE doesn't implement partial traditional application" - ) - - rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) - - return rx - def __call__(self, x, offset: int = 0): shape = x.shape x = mx.reshape(x, (-1, shape[-2], shape[-1])) - N = x.shape[1] + offset - costheta, sintheta = RoPE.create_cos_sin_theta( - N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype + x = mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + offset=offset, ) - - rope = ( - self._compute_traditional_rope if self.traditional else self._compute_rope - ) - rx = rope(costheta, sintheta, x) - - return mx.reshape(rx, shape) - - @classmethod - def create_cos_sin_theta( - cls, - N: int, - D: int, - offset: int = 0, - base: float = 10000, - scale: float = 1.0, - dtype=mx.float32, - ): - if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: - half_D = D // 2 - positions = mx.arange(offset, N, dtype=dtype) * scale - freqs = mx.exp( - -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) - ) - theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) - cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) - cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta)) - return cls._cos_sin_theta_value + return mx.reshape(x, shape) class SinusoidalPositionalEncoding(Module): diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 1ba037fdc..7dd862033 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -3,6 +3,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp diff --git a/python/src/fast.cpp b/python/src/fast.cpp new file mode 100644 index 000000000..115ea37ec --- /dev/null +++ b/python/src/fast.cpp @@ -0,0 +1,59 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +#include "mlx/fast.h" +#include "mlx/ops.h" +#include "python/src/utils.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +void init_extensions(py::module_& parent_module) { + py::options options; + options.disable_function_signatures(); + + auto m = + parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); + + m.def( + "rope", + [](const array& a, + int dims, + bool traditional, + float base, + float scale, + int offset, + const StreamOrDevice& s /* = {} */) { + return fast::rope(a, dims, traditional, base, scale, offset, s); + }, + "a"_a, + "dims"_a, + py::kw_only(), + "traditional"_a, + "base"_a, + "scale"_a, + "offset"_a, + "stream"_a = none, + R"pbdoc( + rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array + + Apply rotary positional encoding to the input. + + Args: + a (array): Input array. + dims (int): The feature dimensions to be rotated. If the input feature + is larger than dims then the rest is left unchanged. + traditional (bool): If set to ``True`` choose the traditional + implementation which rotates consecutive dimensions. + base (float): The base used to compute angular frequency for + each dimension in the positional encodings. + scale (float): The scale used to scale the positions. + offset (int): The position offset to start at. + + Returns: + array: The output array. + )pbdoc"); +} diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index 81626e565..ee0f469f9 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -17,6 +17,7 @@ void init_random(py::module_&); void init_fft(py::module_&); void init_linalg(py::module_&); void init_constants(py::module_&); +void init_extensions(py::module_&); PYBIND11_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -33,5 +34,6 @@ PYBIND11_MODULE(core, m) { init_fft(m); init_linalg(m); init_constants(m); + init_extensions(m); m.attr("__version__") = TOSTRING(_VERSION_); } diff --git a/python/src/random.cpp b/python/src/random.cpp index bbcb7a2c8..442d81fee 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -133,7 +133,7 @@ void init_random(py::module_& parent_module) { low (scalar or array, optional): Lower bound of the distribution. Default is ``0``. high (scalar or array, optional): Upper bound of the distribution. Default is ``1``. shape (list(int), optional): Shape of the output. Default is ``()``. - key (array, optional): A PRNG key. Default: None. + key (array, optional): A PRNG key. Default: ``None``. dtype (Dtype, optional): Type of the output. Default is ``float32``. Returns: diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py new file mode 100644 index 000000000..1cb4ddcca --- /dev/null +++ b/python/tests/test_fast.py @@ -0,0 +1,158 @@ +# Copyright © 2023-2024 Apple Inc. + +import math +import unittest + +import mlx.core as mx +import mlx_tests + + +def rope_orig(x, dims, traditional, base, scale, offset): + N = x.shape[1] + offset + dtype = x.dtype + half_D = dims // 2 + positions = mx.arange(offset, N, dtype=dtype) * scale + freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)) + theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) + costheta, sintheta = mx.cos(theta), mx.sin(theta) + if traditional: + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) + return mx.reshape(rx, x.shape) + else: + x1 = x[..., : dims // 2] + x2 = x[..., dims // 2 : dims] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + if dims < x.shape[-1]: + rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1) + else: + rx = mx.concatenate([rx1, rx2], axis=-1) + return rx + + +class TestFast(mlx_tests.MLXTestCase): + def test_rope(self): + T = 4 + + # Defaults: dims, dtype, base, scale, offset, traditional + defaults = (8, mx.float32, 10000.0, 1.0, 0, False) + + # Per dtype absolute tolerance + tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2} + + # Test cases: + dtypes = [mx.float32, mx.float16, mx.bfloat16] + bases = [10000.0, 1000000.0] + scales = [1.0, 2.0] + offsets = [0, 3] + traditional = [True, False] + + for traditional in [True, False]: + dims, dtype, _, scale, offset, _ = defaults + for base in bases: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + dims, _, base, scale, offset, _ = defaults + for dtype in dtypes: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + ry = rope_orig( + x.astype(mx.float32), dims, traditional, base, scale, offset + ) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + if dtype != mx.float32: + self.assertLessEqual( + mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max() + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + dims, dtype, base, scale, _, _ = defaults + for offset in offsets: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + dims, dtype, base, _, offset, _ = defaults + for scale in scales: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + def test_fast_transforms(self): + x = mx.random.uniform(shape=(2, 2, 8)) + + defaults = (8, False, 10000.0, 1.0, 0) + dims, traditional, base, scale, offset = defaults + + # VJP + _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),)) + _, vjp_fast_out = mx.vjp( + lambda x: mx.fast.rope( + x, dims, traditional=traditional, base=base, scale=scale, offset=offset + ), + (x,), + (mx.ones_like(x),), + ) + self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0])) + + # JVP + _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),)) + _, jvp_fast_out = mx.jvp( + lambda x: mx.fast.rope( + x, dims, traditional=traditional, base=base, scale=scale, offset=offset + ), + (x,), + (mx.ones_like(x),), + ) + self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0])) + + # VMAP + x = mx.random.uniform(shape=(2, 2, 2, 8)) + vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x) + vmap_fast_out = mx.vmap( + lambda x: mx.fast.rope( + x, dims, traditional=traditional, base=base, scale=scale, offset=offset + ) + )(x) + self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) + + +if __name__ == "__main__": + unittest.main() From 35431a4ac8fbe286e42225f6bb5d521dbb8d7334 Mon Sep 17 00:00:00 2001 From: Diogo Date: Wed, 14 Feb 2024 17:14:58 -0500 Subject: [PATCH 29/42] Adds device context manager (#679) --- ACKNOWLEDGMENTS.md | 2 +- docs/src/conf.py | 1 + docs/src/python/devices_and_streams.rst | 3 +- mlx/ops.cpp | 10 --- mlx/ops.h | 6 +- mlx/utils.cpp | 10 +++ mlx/utils.h | 26 ++++++ python/mlx/utils.py | 1 - python/src/CMakeLists.txt | 1 + python/src/device.cpp | 14 +++- python/src/mlx.cpp | 3 + python/src/stream.cpp | 33 +++++++- python/src/utils.cpp | 81 ++++++++++++++++++ python/tests/test_device.py | 11 +++ python/tests/test_fft.py | 105 ++++++++++++------------ 15 files changed, 230 insertions(+), 77 deletions(-) create mode 100644 python/src/utils.cpp diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 2a3c6c612..36aedc77a 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. -- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support +- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. diff --git a/docs/src/conf.py b/docs/src/conf.py index bec2c976c..0654cf53c 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -26,6 +26,7 @@ extensions = [ python_use_unqualified_type_names = True autosummary_generate = True +autosummary_filename_map = {"mlx.core.Stream": "stream_class"} intersphinx_mapping = { "https://docs.python.org/3": None, diff --git a/docs/src/python/devices_and_streams.rst b/docs/src/python/devices_and_streams.rst index bb9dfae2f..e16ab9875 100644 --- a/docs/src/python/devices_and_streams.rst +++ b/docs/src/python/devices_and_streams.rst @@ -9,9 +9,10 @@ Devices and Streams :toctree: _autosummary Device + Stream default_device set_default_device - Stream default_stream new_stream set_default_stream + stream diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 01ee6d388..549d26512 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) { } // namespace -Stream to_stream(StreamOrDevice s) { - if (std::holds_alternative(s)) { - return default_stream(default_device()); - } else if (std::holds_alternative(s)) { - return default_stream(std::get(s)); - } else { - return std::get(s); - } -} - array arange( double start, double stop, diff --git a/mlx/ops.h b/mlx/ops.h index a4b1dd1ef..f7036b8c6 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -3,18 +3,14 @@ #pragma once #include -#include #include "mlx/array.h" #include "mlx/device.h" #include "mlx/stream.h" +#include "mlx/utils.h" namespace mlx::core { -using StreamOrDevice = std::variant; - -Stream to_stream(StreamOrDevice s); - /** Creation operations */ /** diff --git a/mlx/utils.cpp b/mlx/utils.cpp index eece43717..c6365beb9 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -7,6 +7,16 @@ namespace mlx::core { +Stream to_stream(StreamOrDevice s) { + if (std::holds_alternative(s)) { + return default_stream(default_device()); + } else if (std::holds_alternative(s)) { + return default_stream(std::get(s)); + } else { + return std::get(s); + } +} + void PrintFormatter::print(std::ostream& os, bool val) { if (capitalize_bool) { os << (val ? "True" : "False"); diff --git a/mlx/utils.h b/mlx/utils.h index f28970369..88f47e3e1 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "array.h" #include "device.h" #include "dtype.h" @@ -9,6 +11,30 @@ namespace mlx::core { +using StreamOrDevice = std::variant; +Stream to_stream(StreamOrDevice s); + +struct StreamContext { + public: + StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + auto _s = to_stream(s); + set_default_device(_s.device); + set_default_stream(_s); + } + + ~StreamContext() { + set_default_device(_stream.device); + set_default_stream(_stream); + } + + private: + Stream _stream; +}; + struct PrintFormatter { inline void print(std::ostream& os, bool val); inline void print(std::ostream& os, int16_t val); diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 137a8aae4..802b03831 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,5 +1,4 @@ # Copyright © 2023 Apple Inc. - from collections import defaultdict diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 7dd862033..4df503a4a 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -14,6 +14,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/device.cpp b/python/src/device.cpp index 8c36f0f85..c88144520 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -12,7 +12,8 @@ using namespace py::literals; using namespace mlx::core; void init_device(py::module_& m) { - auto device_class = py::class_(m, "Device"); + auto device_class = py::class_( + m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); py::enum_(m, "DeviceType") .value("cpu", Device::DeviceType::cpu) .value("gpu", Device::DeviceType::gpu) @@ -39,6 +40,13 @@ void init_device(py::module_& m) { py::implicitly_convertible(); - m.def("default_device", &default_device); - m.def("set_default_device", &set_default_device, "device"_a); + m.def( + "default_device", + &default_device, + R"pbdoc(Get the default device.)pbdoc"); + m.def( + "set_default_device", + &set_default_device, + "device"_a, + R"pbdoc(Set the default device.)pbdoc"); } diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index ee0f469f9..5fb9e74e2 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -18,6 +18,7 @@ void init_fft(py::module_&); void init_linalg(py::module_&); void init_constants(py::module_&); void init_extensions(py::module_&); +void init_utils(py::module_&); PYBIND11_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -35,5 +36,7 @@ PYBIND11_MODULE(core, m) { init_linalg(m); init_constants(m); init_extensions(m); + init_utils(m); + m.attr("__version__") = TOSTRING(_VERSION_); } diff --git a/python/src/stream.cpp b/python/src/stream.cpp index 7b1b2f55d..768795fc1 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -12,7 +12,12 @@ using namespace py::literals; using namespace mlx::core; void init_stream(py::module_& m) { - py::class_(m, "Stream") + py::class_( + m, + "Stream", + R"pbdoc( + A stream for running operations on a given device. + )pbdoc") .def(py::init(), "index"_a, "device"_a) .def_readonly("device", &Stream::device) .def( @@ -28,7 +33,27 @@ void init_stream(py::module_& m) { py::implicitly_convertible(); - m.def("default_stream", &default_stream, "device"_a); - m.def("set_default_stream", &set_default_stream, "stream"_a); - m.def("new_stream", &new_stream, "device"_a); + m.def( + "default_stream", + &default_stream, + "device"_a, + R"pbdoc(Get the device's default stream.)pbdoc"); + m.def( + "set_default_stream", + &set_default_stream, + "stream"_a, + R"pbdoc( + Set the default stream. + + This will make the given stream the default for the + streams device. It will not change the default device. + + Args: + stream (stream): Stream to make the default. + )pbdoc"); + m.def( + "new_stream", + &new_stream, + "device"_a, + R"pbdoc(Make a new stream on the given device.)pbdoc"); } diff --git a/python/src/utils.cpp b/python/src/utils.cpp new file mode 100644 index 000000000..c07016709 --- /dev/null +++ b/python/src/utils.cpp @@ -0,0 +1,81 @@ + +#include "mlx/utils.h" +#include +#include +#include + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +// Slightly different from the original, with python context on init we are not +// in the context yet. Only create the inner context on enter then delete on +// exit. +class PyStreamContext { + public: + PyStreamContext(StreamOrDevice s) : _inner(nullptr) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + _s = s; + } + + void enter() { + _inner = new StreamContext(_s); + } + + void exit() { + if (_inner != nullptr) { + delete _inner; + _inner = nullptr; + } + } + + private: + StreamOrDevice _s; + StreamContext* _inner; +}; + +void init_utils(py::module_& m) { + py::class_(m, "StreamContext", R"pbdoc( + A context manager for setting the current device and stream. + + See :func:`stream` for usage. + + Args: + s: The stream or device to set as the default. + )pbdoc") + .def(py::init(), "s"_a) + .def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) + .def( + "__exit__", + [](PyStreamContext& scm, + const std::optional& exc_type, + const std::optional& exc_value, + const std::optional& traceback) { scm.exit(); }); + m.def( + "stream", + [](StreamOrDevice s) { return PyStreamContext(s); }, + "s"_a, + R"pbdoc( + Create a context manager to set the default device and stream. + + Args: + s: The :obj:`Stream` or :obj:`Device` to set as the default. + + Returns: + A context manager that sets the default device and stream. + + Example: + + .. code-block::python + + import mlx.core as mx + + # Create a context manager for the default device and stream. + with mx.stream(mx.cpu): + # Operations here will use mx.cpu by default. + pass + )pbdoc"); +} diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 8aac105bc..53826cad7 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase): # Restore device mx.set_default_device(device) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_device_context(self): + default = mx.default_device() + diff = mx.cpu if default == mx.gpu else mx.gpu + self.assertNotEqual(default, diff) + with mx.stream(diff): + a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2))) + mx.eval(a) + self.assertEqual(mx.default_device(), diff) + self.assertEqual(mx.default_device(), default) + def test_op_on_device(self): x = mx.array(1.0) y = mx.array(1.0) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 4be12e21f..14473afa1 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -19,72 +19,73 @@ class TestFFT(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) def test_fft(self): - default = mx.default_device() - mx.set_default_device(mx.cpu) - def check_mx_np(op_mx, op_np, a_np, **kwargs): out_np = op_np(a_np, **kwargs) a_mx = mx.array(a_np) out_mx = op_mx(a_mx, **kwargs) self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np) + with mx.stream(mx.cpu): + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np) - # Check with slicing and padding - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) - check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) + # Check with slicing and padding + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) - # Check different axes - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) - check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) + # Check different axes + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) - # Check real fft - a_np = np.random.rand(100).astype(np.float32) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) + # Check real fft + a_np = np.random.rand(100).astype(np.float32) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) - # Check real inverse - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) - - mx.set_default_device(default) + # Check real inverse + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) def test_fftn(self): - default = mx.default_device() - mx.set_default_device(mx.cpu) + with mx.stream(mx.cpu): + r = np.random.randn(8, 8, 8).astype(np.float32) + i = np.random.randn(8, 8, 8).astype(np.float32) + a = r + 1j * i - r = np.random.randn(8, 8, 8).astype(np.float32) - i = np.random.randn(8, 8, 8).astype(np.float32) - a = r + 1j * i + axes = [None, (1, 2), (2, 1), (0, 2)] + shapes = [None, (10, 5), (5, 10)] + ops = [ + "fft2", + "ifft2", + "rfft2", + "irfft2", + "fftn", + "ifftn", + "rfftn", + "irfftn", + ] - axes = [None, (1, 2), (2, 1), (0, 2)] - shapes = [None, (10, 5), (5, 10)] - ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] - - for op, ax, s in itertools.product(ops, axes, shapes): - x = a - if op in ["rfft2", "rfftn"]: - x = r - self.check_mx_np(op, x, axes=ax, s=s) - - mx.set_default_device(default) + for op, ax, s in itertools.product(ops, axes, shapes): + x = a + if op in ["rfft2", "rfftn"]: + x = r + self.check_mx_np(op, x, axes=ax, s=s) if __name__ == "__main__": From 85143fecdd87e7b2e390aa0c39123195ef485cf6 Mon Sep 17 00:00:00 2001 From: toji Date: Thu, 15 Feb 2024 20:55:38 +0530 Subject: [PATCH 30/42] improved error msg for invalid axis(`mx.split`) (#685) * improved error msg for invalid axis(`mx.split`) * Apply suggestions from code review Co-authored-by: Awni Hannun * fixed formatting issue --------- Co-authored-by: Awni Hannun --- mlx/ops.cpp | 7 +++++++ python/tests/test_ops.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 549d26512..32af8a078 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -622,6 +622,13 @@ std::vector split( std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) { + auto ax = axis < 0 ? axis + a.ndim() : axis; + if (ax < 0 || ax >= a.ndim()) { + std::ostringstream msg; + msg << "Invalid axis " << axis << " passed to split" + << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } auto q_and_r = std::ldiv(a.shape(axis), num_splits); if (q_and_r.rem) { std::ostringstream msg; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 5588ebd62..66e683303 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1027,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(y.tolist(), [[3, 4]]) self.assertEqual(z.tolist(), [[5, 6]]) + with self.assertRaises(ValueError): + mx.split(a, 3, axis=2) + a = mx.arange(8) x, y, z = mx.split(a, [1, 5]) self.assertEqual(x.tolist(), [0]) From 818cda16bcc4a68a1e971874c42f909df30145e3 Mon Sep 17 00:00:00 2001 From: Srimukh Sripada Date: Thu, 15 Feb 2024 20:26:20 +0100 Subject: [PATCH 31/42] Support LR schedulers (#334) * Add a few LR schedulers * Move parents's constructor call to the top * Fix docstring * refactor optimizers into two files * add docs * nit * Fix Callable type annotation for python 3.8 --------- Co-authored-by: Awni Hannun Co-authored-by: Angelos Katharopoulos --- docs/.gitignore | 1 + docs/src/python/optimizers.rst | 20 +---- .../python/optimizers/common_optimizers.rst | 20 +++++ .../src/python/{ => optimizers}/optimizer.rst | 0 docs/src/python/optimizers/schedulers.rst | 13 +++ python/mlx/optimizers/__init__.py | 4 + python/mlx/{ => optimizers}/optimizers.py | 71 ++++++++++----- python/mlx/optimizers/schedulers.py | 86 +++++++++++++++++++ python/src/array.cpp | 6 ++ python/tests/test_optimizers.py | 61 +++++++++++-- 10 files changed, 235 insertions(+), 47 deletions(-) create mode 100644 docs/src/python/optimizers/common_optimizers.rst rename docs/src/python/{ => optimizers}/optimizer.rst (100%) create mode 100644 docs/src/python/optimizers/schedulers.rst create mode 100644 python/mlx/optimizers/__init__.py rename python/mlx/{ => optimizers}/optimizers.py (92%) create mode 100644 python/mlx/optimizers/schedulers.py diff --git a/docs/.gitignore b/docs/.gitignore index 5c2693cb6..fa80a135e 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,3 @@ src/python/_autosummary*/ src/python/nn/_autosummary*/ +src/python/optimizers/_autosummary*/ diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst index 4ef43d50f..f437ddc15 100644 --- a/docs/src/python/optimizers.rst +++ b/docs/src/python/optimizers.rst @@ -31,20 +31,6 @@ model's parameters and the **optimizer state**. .. toctree:: - optimizer - -.. currentmodule:: mlx.optimizers - -.. autosummary:: - :toctree: _autosummary - :template: optimizers-template.rst - - SGD - RMSprop - Adagrad - Adafactor - AdaDelta - Adam - AdamW - Adamax - Lion + optimizers/optimizer + optimizers/common_optimizers + optimizers/schedulers diff --git a/docs/src/python/optimizers/common_optimizers.rst b/docs/src/python/optimizers/common_optimizers.rst new file mode 100644 index 000000000..41b3fba03 --- /dev/null +++ b/docs/src/python/optimizers/common_optimizers.rst @@ -0,0 +1,20 @@ +.. _common_optimizers: + +Common Optimizers +================= + +.. currentmodule:: mlx.optimizers + +.. autosummary:: + :toctree: _autosummary + :template: optimizers-template.rst + + SGD + RMSprop + Adagrad + Adafactor + AdaDelta + Adam + AdamW + Adamax + Lion diff --git a/docs/src/python/optimizer.rst b/docs/src/python/optimizers/optimizer.rst similarity index 100% rename from docs/src/python/optimizer.rst rename to docs/src/python/optimizers/optimizer.rst diff --git a/docs/src/python/optimizers/schedulers.rst b/docs/src/python/optimizers/schedulers.rst new file mode 100644 index 000000000..a83883ddb --- /dev/null +++ b/docs/src/python/optimizers/schedulers.rst @@ -0,0 +1,13 @@ +.. _schedulers: + +Schedulers +========== + +.. currentmodule:: mlx.optimizers + +.. autosummary:: + :toctree: _autosummary + + step_decay + exponential_decay + cosine_decay diff --git a/python/mlx/optimizers/__init__.py b/python/mlx/optimizers/__init__.py new file mode 100644 index 000000000..6e8e0ccd4 --- /dev/null +++ b/python/mlx/optimizers/__init__.py @@ -0,0 +1,4 @@ +# Copyright © 2023-2024 Apple Inc. + +from mlx.optimizers.optimizers import * +from mlx.optimizers.schedulers import * diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers/optimizers.py similarity index 92% rename from python/mlx/optimizers.py rename to python/mlx/optimizers/optimizers.py index 4a53d4681..16928625f 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -1,7 +1,7 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import math -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import mlx.core as mx from mlx.utils import tree_map @@ -12,9 +12,10 @@ class Optimizer: optimizer on a per-parameter basis and apply it to a parameter tree. """ - def __init__(self): + def __init__(self, schedulers=None): self._initialized = False - self._state = {} + self._state = {"step": mx.array(0, mx.uint64)} + self._schedulers = {k: v for k, v in (schedulers or {}).items()} def update(self, model: "mlx.nn.Module", gradients: dict): """Apply the gradients to the parameters of the model and update the @@ -44,9 +45,8 @@ class Optimizer: >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9) >>> model = nn.Linear(2, 2) >>> optimizer.init(model.trainable_parameters()) - >>> optimizer.state - {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0], - [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}} + >>> optimizer.state.keys() + dict_keys(['step', 'learning_rate', 'weight', 'bias']) """ self._state.update(tree_map(lambda x: {}, parameters)) tree_map(self.init_single, parameters, self._state) @@ -76,6 +76,15 @@ class Optimizer: """ if not self._initialized: self.init(gradients) + + # Update any scheduled variables + for param, scheduler in self._schedulers.items(): + self.state[param] = scheduler(self.step) + + # Increment the step + self.state["step"] = self.step + 1 + + # Apply the update return tree_map(self.apply_single, gradients, parameters, self.state) def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): @@ -97,14 +106,31 @@ class Optimizer: def state(self, state: dict): self._state = state + @property + def step(self): + return self.state["step"] + @property def learning_rate(self): return self.state["learning_rate"] @learning_rate.setter - def learning_rate(self, learning_rate: mx.array): + def learning_rate(self, learning_rate: Union[float, mx.array]): self.state["learning_rate"] = mx.array(learning_rate) + def _maybe_schedule( + self, name: str, param: Union[float, Callable[[mx.array], mx.array]] + ): + """ + To be used by derived classes to optionally put a parameter on a schedule. + """ + if isinstance(param, Callable): + self._schedulers[name] = param + param = param(self.step) + else: + param = mx.array(param) + self.state[name] = param + class SGD(Optimizer): r"""The stochastic gradient descent optimizer. @@ -117,7 +143,7 @@ class SGD(Optimizer): w_{t+1} &= w_t - \lambda v_{t+1} Args: - learning_rate (float): The learning rate :math:`\lambda`. + learning_rate (float or callable): The learning rate :math:`\lambda`. momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0`` weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0`` dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0`` @@ -126,7 +152,7 @@ class SGD(Optimizer): def __init__( self, - learning_rate: float, + learning_rate: Union[float, Callable[[mx.array], mx.array]], momentum: float = 0.0, weight_decay: float = 0.0, dampening: float = 0.0, @@ -138,7 +164,7 @@ class SGD(Optimizer): ) super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.momentum = momentum self.weight_decay = weight_decay self.dampening = dampening @@ -194,7 +220,7 @@ class RMSprop(Optimizer): def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.alpha = alpha self.eps = eps @@ -246,7 +272,7 @@ class Adagrad(Optimizer): def __init__(self, learning_rate: float, eps: float = 1e-8): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.eps = eps if self.eps < 0.0: @@ -295,7 +321,7 @@ class AdaDelta(Optimizer): def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.rho = rho self.eps = eps if self.rho < 0.0: @@ -361,7 +387,7 @@ class Adam(Optimizer): ): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.betas = betas self.eps = eps @@ -526,7 +552,7 @@ class Lion(Optimizer): ): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.betas = betas self.weight_decay = weight_decay @@ -596,7 +622,7 @@ class Adafactor(Optimizer): ): super().__init__() if learning_rate is not None: - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.eps = eps self.clip_threshold = clip_threshold self.decay_rate = decay_rate @@ -608,7 +634,6 @@ class Adafactor(Optimizer): def init_single(self, parameter: mx.array, state: dict): """Initialize optimizer state""" - state["step"] = 0 if parameter.ndim >= 2: shape = parameter.shape dtype = parameter.dtype @@ -626,10 +651,11 @@ class Adafactor(Optimizer): def _compute_learning_rate(self, step, parameter_rms): if self.relative_step: min_step = 1e-6 * step if self.warmup_init else 1e-2 - relative_step_size = min(min_step, 1 / math.sqrt(step)) + relative_step_size = mx.minimum(min_step, mx.rsqrt(step)) else: - relative_step_size = self.learning_rate.astype(parameter_rms) + relative_step_size = self.learning_rate + relative_step_size = relative_step_size.astype(parameter_rms.dtype) parameter_scale = 1.0 if self.scale_parameter: parameter_scale = mx.maximum(self.eps[1], parameter_rms) @@ -648,13 +674,12 @@ class Adafactor(Optimizer): """Performs the Adafactor parameter and state update.""" factored = gradient.ndim >= 2 - step = state["step"] + 1 - state["step"] = step + step = self.step use_first_moment = self.beta_1 is not None parameter_rms = self._compute_rms(parameter) learning_rate = self._compute_learning_rate(step, parameter_rms) - beta_2 = 1.0 - math.pow(step, self.decay_rate) + beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype) update = mx.square(gradient) + self.eps[0] if factored: diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py new file mode 100644 index 000000000..da058c03a --- /dev/null +++ b/python/mlx/optimizers/schedulers.py @@ -0,0 +1,86 @@ +# Copyright © 2023-2024 Apple Inc. + +import math + +import mlx.core as mx + + +def exponential_decay(init: float, decay_rate: float): + r"""Make an exponential decay scheduler. + + Args: + init (float): Initial value. + decay_rate (float): Multiplicative factor to decay by. + + Example: + >>> lr_schedule = optim.exponential_decay(1e-1, 0.9) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(5): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.06561, dtype=float32) + """ + + def schedule(step): + return init * decay_rate**step + + return schedule + + +def step_decay(init: float, decay_rate: float, step_size: int): + r"""Make a step decay scheduler. + + Args: + init (float): Initial value. + decay_rate (float): Multiplicative factor to decay by. + step_size (int): Decay every ``step_size`` steps. + + Example: + + >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(21): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.081, dtype=float32) + """ + + def schedule(step): + return init * (decay_rate ** (step // step_size)) + + return schedule + + +def cosine_decay(init: float, decay_steps: int): + r"""Make a cosine decay scheduler. + + Args: + init (float): Initial value. + decay_steps (int): Number of steps to decay over. The decayed + value is constant for steps beyond ``decay_steps``. + + Example: + + >>> lr_schedule = optim.cosine_decay(1e-1, 1000) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(5): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.0999961, dtype=float32) + """ + + def scheduler(step): + s = mx.minimum(step, decay_steps) + decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s)) + return init * decay + + return scheduler diff --git a/python/src/array.cpp b/python/src/array.cpp index 4395d50e6..57b867dbc 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -971,6 +971,12 @@ void init_array(py::module_& m) { return power(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__rpow__", + [](const array& a, const ScalarOrArray v) { + return power(to_array(v, a.dtype()), a); + }, + "other"_a) .def( "__ipow__", [](array& a, const ScalarOrArray v) { diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index f894a7510..f978943de 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import inspect +import math import unittest from functools import partial @@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map def get_all_optimizers(): classes = dict() for name, obj in inspect.getmembers(opt): - if inspect.isclass(obj): - if obj.__name__ not in ["Optimizer"]: - classes[name] = obj + if ( + inspect.isclass(obj) + and issubclass(obj, opt.Optimizer) + and obj != opt.Optimizer + ): + classes[name] = obj return classes @@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase): x = mx.zeros((5, 5)) grad = mx.ones_like(x) optimizer = opt.Adafactor() - optimizer.init(x) for _ in range(2): - xp = optimizer.apply_single(grad, x, optimizer.state) + xp = optimizer.apply_gradients(grad, x) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) x = mx.zeros((5, 5), mx.float16) grad = mx.ones_like(x) optimizer = opt.Adafactor() - optimizer.init(x) for _ in range(2): - xp = optimizer.apply_single(grad, x, optimizer.state) + xp = optimizer.apply_gradients(grad, x) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) self.assertEqual(optimizer.state["step"], 2) @@ -294,5 +296,50 @@ class TestOptimizers(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) +class TestSchedulers(unittest.TestCase): + def test_decay_lr(self): + for optim_class in optimizers_dict.values(): + lr_schedule = opt.step_decay(1e-1, 0.9, 1000) + optimizer = optim_class(learning_rate=lr_schedule) + + params = {"w": mx.ones((5, 5))} + grads = tree_map(lambda x: mx.ones_like(x), params) + + for it in range(10): + expected_lr = 0.1 * (0.9**it) + self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7) + return optimizer.apply_gradients(grads, params) + + def test_step_decay(self): + lr_schedule = opt.step_decay(1e-1, 0.9, 1000) + lr = lr_schedule(2500) + expected_lr = 0.1 * (0.9**2) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_exponential_decay(self): + lr_schedule = opt.exponential_decay(1e-1, 0.99) + lr = lr_schedule(10) + expected_lr = 0.1 * (0.99**10) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_cosine_decay(self): + lr_schedule = opt.cosine_decay(0.1, 10) + lr = lr_schedule(4) + expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10)) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_compile_with_schedule(self): + lr_schedule = opt.exponential_decay(1e-1, 0.9) + optimizer = opt.SGD(learning_rate=lr_schedule) + + @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state) + def update(): + optimizer.update({}, {}) + + for step in range(5): + update() + self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item()) + + if __name__ == "__main__": unittest.main() From 165abf0e4c22e6ee1c3d501737e44752f355ff85 Mon Sep 17 00:00:00 2001 From: Mike Drob Date: Thu, 15 Feb 2024 19:30:35 -0600 Subject: [PATCH 32/42] Auto-run PRs from contributors (#692) --- .circleci/config.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5f26778c4..250f35faa 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,8 @@ version: 2.1 +orbs: + apple: ml-explore/pr-approval@0.1.0 + parameters: nightly_build: type: boolean @@ -253,6 +256,8 @@ workflows: jobs: - hold: type: approval + - apple/authenticate: + context: pr-approval - mac_build_and_test: requires: [ hold ] - linux_build_and_test: From a000d2288c206c947c5a52199a8d65bfb00802ec Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Fri, 16 Feb 2024 18:01:59 +0400 Subject: [PATCH 33/42] feat: update black pre-commit hook to 24.2.0 (#696) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 279ab5c91..dd5ebec30 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: clang-format # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.1.1 + rev: 24.2.0 hooks: - id: black - repo: https://github.com/pycqa/isort From bf7cd29970b603ed3e177cb2f4f6dc246a653555 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 16 Feb 2024 08:44:08 -0800 Subject: [PATCH 34/42] version bump (#698) --- CMakeLists.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e889353ce..8b1a4cf52 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) - set(MLX_VERSION 0.2.0) + set(MLX_VERSION 0.3.0) endif() # --------------------- Processor tests ------------------------- diff --git a/setup.py b/setup.py index cce1f8537..961655419 100644 --- a/setup.py +++ b/setup.py @@ -152,7 +152,7 @@ if __name__ == "__main__": setup( name="mlx", - version=get_version("0.2.0"), + version=get_version("0.3.0"), author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", From c3965fc5ee665083714b391f249ec9029895f216 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 16 Feb 2024 19:16:39 -0800 Subject: [PATCH 35/42] Separate fast ops and primitives (#699) --- mlx/backend/common/rope.cpp | 3 +- mlx/backend/metal/rope.cpp | 3 +- mlx/backend/no_metal/primitives.cpp | 2 +- mlx/fast.cpp | 2 + mlx/fast.h | 66 +--------------------------- mlx/fast_primitives.h | 68 +++++++++++++++++++++++++++++ 6 files changed, 74 insertions(+), 70 deletions(-) create mode 100644 mlx/fast_primitives.h diff --git a/mlx/backend/common/rope.cpp b/mlx/backend/common/rope.cpp index c0c2bba8e..15b5de7e5 100644 --- a/mlx/backend/common/rope.cpp +++ b/mlx/backend/common/rope.cpp @@ -1,7 +1,6 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/fast.h" -#include "mlx/primitives.h" +#include "mlx/fast_primitives.h" namespace mlx::core::fast { diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 29295f3ac..fdea57985 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -1,8 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/metal/utils.h" -#include "mlx/fast.h" -#include "mlx/primitives.h" +#include "mlx/fast_primitives.h" namespace mlx::core::fast { diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index bd4026e2c..8e66f56b3 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/primitives.h" -#include "mlx/fast.h" +#include "mlx/fast_primitives.h" #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 96d4f03ce..ee28138f1 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,6 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/fast.h" +#include "mlx/fast_primitives.h" +#include "mlx/ops.h" #include "mlx/transforms.h" namespace mlx::core::fast { diff --git a/mlx/fast.h b/mlx/fast.h index 5deac0cdb..48ac90a5a 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -2,40 +2,10 @@ #pragma once -#include "mlx/ops.h" -#include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core::fast { -// Custom primitive accepts a fallback function which it uses for -// transformations. Transformations are virtual so that derived classes may to -// override the default behavior -class Custom : public Primitive { - public: - explicit Custom( - Stream stream, - std::function(std::vector)> fallback) - : Primitive(stream), fallback_(fallback){}; - - virtual std::pair, std::vector> vmap( - const std::vector& inputs, - const std::vector& axes) override; - - virtual std::vector jvp( - const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) override; - - virtual std::vector vjp( - const std::vector& primals, - const std::vector& cotangents, - const std::vector& argnums, - const std::vector& outputs) override; - - private: - std::function(std::vector)> fallback_; -}; - array rope( const array& x, int dims, @@ -45,38 +15,4 @@ array rope( int offset, StreamOrDevice s /* = {} */); -class RoPE : public Custom { - public: - RoPE( - Stream stream, - std::function(std::vector)> fallback, - int dims, - bool traditional, - float base, - float scale, - int offset) - : Custom(stream, fallback), - dims_(dims), - traditional_(traditional), - base_(base), - scale_(scale), - offset_(offset){}; - - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override; - void eval_gpu(const std::vector& inputs, std::vector& outputs) - override; - - DEFINE_PRINT(RoPE) - bool is_equivalent(const Primitive& other) const override; - - private: - std::function(std::vector)> fallback_; - int dims_; - bool traditional_; - float base_; - float scale_; - int offset_; -}; - } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h new file mode 100644 index 000000000..acb5d0046 --- /dev/null +++ b/mlx/fast_primitives.h @@ -0,0 +1,68 @@ +#include "mlx/primitives.h" + +namespace mlx::core::fast { + +// Custom primitive accepts a fallback function which it uses for +// transformations. Transformations are virtual so that derived classes may to +// override the default behavior +class Custom : public Primitive { + public: + explicit Custom( + Stream stream, + std::function(std::vector)> fallback) + : Primitive(stream), fallback_(fallback){}; + + virtual std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + + virtual std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + + virtual std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + private: + std::function(std::vector)> fallback_; +}; + +class RoPE : public Custom { + public: + RoPE( + Stream stream, + std::function(std::vector)> fallback, + int dims, + bool traditional, + float base, + float scale, + int offset) + : Custom(stream, fallback), + dims_(dims), + traditional_(traditional), + base_(base), + scale_(scale), + offset_(offset){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_PRINT(RoPE) + bool is_equivalent(const Primitive& other) const override; + + private: + std::function(std::vector)> fallback_; + int dims_; + bool traditional_; + float base_; + float scale_; + int offset_; +}; + +} // namespace mlx::core::fast From dc937b8ed30160d5868a39ad00d4cca4ee730289 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 17 Feb 2024 06:54:32 -0800 Subject: [PATCH 36/42] CPU compile (#691) * build and load shared object for cpu compile * nits * cpu compile tests pass * cpu compile tests pass * fix preamble for g++ * donation * fix gpu buffer donation * reuse prebuilt libraries * faster contiguity conditoins * fix test * rid compiler warning * fast erf * Fix float16 for compile and add more types to cpu compile * Remove a forgotten comment * use cached libs * nits --------- Co-authored-by: Angelos Katharopoulos --- mlx/backend/accelerate/primitives.cpp | 1 - mlx/backend/common/compiled.cpp | 531 +++++++- mlx/backend/common/compiled.h | 52 + mlx/backend/common/compiled_preamble.h | 1121 +++++++++++++++++ mlx/backend/common/default_primitives.cpp | 1 - mlx/backend/metal/compiled.cpp | 148 +-- mlx/backend/metal/kernels/compiled_preamble.h | 2 + mlx/compile.cpp | 3 + mlx/fast_primitives.h | 4 +- mlx/primitives.h | 2 - mlx/types/complex.h | 35 +- mlx/utils.h | 2 +- tests/compile_tests.cpp | 6 +- 13 files changed, 1716 insertions(+), 192 deletions(-) create mode 100644 mlx/backend/common/compiled.h create mode 100644 mlx/backend/common/compiled_preamble.h diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 499cc0ce4..4cccd35ae 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -33,7 +33,6 @@ DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT(Ceil) -DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Copy) DEFAULT_MULTI(CustomVJP) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 149556530..8bf1f43e4 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -1,59 +1,506 @@ // Copyright © 2023-2024 Apple Inc. -#include +#include +#include +#include +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/compiled_preamble.h" +#include "mlx/backend/common/utils.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core { -// Build the real tape -std::pair, std::vector> trace_to_real( - const std::vector& trace_tape, - const std::vector& trace_inputs, - const std::vector& trace_outputs, - const std::vector& inputs) { - std::unordered_map trace_to_real; - for (int i = 0; i < inputs.size(); ++i) { - trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); - } - std::queue tape; - for (auto& a : trace_tape) { - // Find real inputs - std::vector real_inputs; - for (auto& in : a.inputs()) { - real_inputs.push_back(trace_to_real.at(in.id())); - } - tape.push( - array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs))); - trace_to_real.insert({a.id(), tape.back()}); - } - - std::vector outputs; - for (auto& o : trace_outputs) { - outputs.push_back(trace_to_real.at(o.id())); - } - return {tape, outputs}; +std::string get_temp_file(const std::string& name) { + return std::filesystem::temp_directory_path().append(name); } -void Compiled::eval( +std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids) { + std::ostringstream os; + std::ostringstream constant_hasher; + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. + for (auto& a : tape) { + a.primitive().print(os); + } + os << "_"; + + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << ((x.size() == 1) ? "S" : "V"); + } + } + os << "_"; + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + return os.str(); +} + +void print_constant(std::ostream& os, const array& x) { + switch (x.dtype()) { + case float32: + return print_float_constant(os, x); + case float16: + return print_float_constant(os, x); + case bfloat16: + return print_float_constant(os, x); + case complex64: + return print_complex_constant(os, x); + case int8: + return print_int_constant(os, x); + case int16: + return print_int_constant(os, x); + case int32: + return print_int_constant(os, x); + case int64: + return print_int_constant(os, x); + case uint8: + return print_int_constant(os, x); + case uint16: + return print_int_constant(os, x); + case uint32: + return print_int_constant(os, x); + case uint64: + return print_int_constant(os, x); + case bool_: + os << std::boolalpha << x.item(); + return; + default: + throw std::runtime_error("Unsupported constant type"); + } +} + +std::string get_type_string(Dtype d) { + switch (d) { + case float32: + return "float"; + case float16: + return "float16_t"; + case bfloat16: + return "bfloat16_t"; + case complex64: + return "complex64_t"; + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + default: { + std::ostringstream msg; + msg << "Unsupported compilation type " << d; + throw std::runtime_error(msg.str()); + } + } +} + +inline bool is_scalar(const array& x) { + return x.size() == 1; +}; + +// Return a pointer to a compiled function +void* compile( + const std::string& kernel_name, + const std::string& source_code = "") { + struct DLib { + DLib(const std::string& libname) { + lib = dlopen(libname.c_str(), RTLD_NOW); + if (!lib) { + std::ostringstream msg; + msg << "Could not load C++ shared library " << dlerror(); + throw std::runtime_error(msg.str()); + } + } + + ~DLib() { + dlclose(lib); + } + void* lib; + }; + // Statics to cache compiled libraries and functions + static std::list libs; + static std::unordered_map kernels; + if (auto it = kernels.find(kernel_name); it != kernels.end()) { + return it->second; + } + if (source_code.empty()) { + return nullptr; + } + + std::ostringstream shared_lib_name; + shared_lib_name << "lib" << kernel_name << ".so"; + auto shared_lib_path = get_temp_file(shared_lib_name.str()); + bool lib_exists = false; + { + std::ifstream f(shared_lib_path.c_str()); + lib_exists = f.good(); + } + + if (!lib_exists) { + // Open source file and write source code to it + std::ostringstream source_file_name; + source_file_name << kernel_name << ".cpp"; + auto source_file_path = get_temp_file(source_file_name.str()); + + std::ofstream source_file(source_file_path); + source_file << source_code; + source_file.close(); + + std::ostringstream build_command; + build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " + << source_file_path << " -o " << shared_lib_path; + std::string build_command_str = build_command.str(); + system(build_command_str.c_str()); + } + + // load library + libs.emplace_back(shared_lib_path); + + // Load function + void* fun = dlsym(libs.back().lib, kernel_name.c_str()); + if (!fun) { + std::ostringstream msg; + msg << "[Compile::eval_cpu] Failed to load compiled function " + << kernel_name << std::endl + << dlerror(); + throw std::runtime_error(msg.str()); + } + kernels.insert({kernel_name, fun}); + return fun; +} + +inline void build_kernel( + std::ostream& os, + const std::string& kernel_name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids, + bool contiguous, + int ndim) { + // All outputs should have the exact same shape and will be row contiguous + auto output_shape = outputs[0].shape(); + auto output_strides = outputs[0].strides(); + + // Constants are scalars that are captured by value and cannot change + auto is_constant = [&constant_ids](const array& x) { + return constant_ids.find(x.id()) != constant_ids.end(); + }; + + NodeNamer namer; + + // Start the kernel + os << "void " << kernel_name << "(void** args) {" << std::endl; + + // Add the input arguments + int cnt = 0; + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + // Skip constants from the input list + if (is_constant(x)) { + continue; + } + + auto tstr = get_type_string(x.dtype()); + os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ + << "];" << std::endl; + // Scalars and contiguous need no strides + if (!is_scalar(x) && !contiguous) { + os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++ + << "];" << std::endl; + } + } + + // Add the output arguments + for (auto& x : outputs) { + auto tstr = get_type_string(x.dtype()); + os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr + << "*)args[" << cnt++ << "];" << std::endl; + } + // Add output strides and shape to extract the indices. + if (!contiguous) { + os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl; + } else { + os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; + } + + if (contiguous) { + os << " for (size_t i = 0; i < size; ++i) {" << std::endl; + } else { + for (int d = 0; d < ndim; ++d) { + os << " for (int i" << d << " = 0; i" << d << " < shape[" << d + << "]; ++i" << d << ") {" << std::endl; + } + } + + // Read the inputs in tmps + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + if (is_constant(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; + print_constant(os, x); + os << ";" << std::endl; + } else if (is_scalar(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[0];" << std::endl; + } else if (contiguous) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[i];" << std::endl; + } else { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *" + << xname << ";" << std::endl; + } + } + + // Actually write the computation + for (auto& x : tape) { + os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) + << " = "; + if (is_static_cast(x.primitive())) { + os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" + << namer.get_name(x.inputs()[0]) << ");" << std::endl; + } else { + x.primitive().print(os); + os << "()("; + for (int i = 0; i < x.inputs().size() - 1; i++) { + os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; + } + os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; + } + } + + // Write the outputs from tmps + for (auto& x : outputs) { + if (contiguous) { + os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x) + << ";" << std::endl; + } else { + os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x) + << ";" << std::endl; + } + } + + // Close loops + if (contiguous) { + os << " }" << std::endl; + } else { + for (int d = ndim - 1; d >= 0; --d) { + // Update pointers + for (auto& x : inputs) { + if (is_constant(x) || is_scalar(x)) { + continue; + } + auto& xname = namer.get_name(x); + os << " " << xname << " += " << xname << "_strides[" << d << "];" + << std::endl; + if (d < ndim - 1) { + os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]" + << " * shape[" << d + 1 << "];" << std::endl; + } + } + os << " }" << std::endl; + } + } + + // Finish the kernel + os << "}" << std::endl; +} + +void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { - // Make the a real tape from the tracers - auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs); - - // Run the tape - while (!tape.empty()) { - auto a = std::move(tape.front()); - tape.pop(); - auto outputs = a.outputs(); - a.primitive().eval_cpu(a.inputs(), outputs); - a.detach(); + if (kernel_lib_.empty()) { + kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); } - // Copy results into outputs - for (int o = 0; o < real_outputs.size(); ++o) { - outputs[o].copy_shared_buffer(real_outputs[o]); + // Figure out which kernel we are using + auto& shape = outputs[0].shape(); + bool contiguous = true; + { + bool all_contig = true; + bool all_row_contig = true; + bool all_col_contig = true; + int non_scalar_inputs = 0; + for (auto& x : inputs) { + if (x.size() == 1) { + continue; + } + non_scalar_inputs++; + bool shape_eq = x.shape() == shape; + all_contig &= (x.flags().contiguous && shape_eq); + all_row_contig &= (x.flags().row_contiguous && shape_eq); + all_col_contig &= (x.flags().col_contiguous && shape_eq); + } + if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) { + contiguous = false; + } else if (non_scalar_inputs == 1 && !all_contig) { + contiguous = false; + } } + + // Handle all broadcasting and collect function input arguments + std::vector args; + std::vector> strides; + for (int i = 0; i < inputs.size(); i++) { + // Skip constants. + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + args.push_back((void*)x.data()); + + if (contiguous || x.size() <= 1) { + continue; + } + + // Broadcast the input to the output shape. + std::vector xstrides; + int j = 0; + for (; j < shape.size() - x.ndim(); j++) { + if (shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (int i = 0; i < x.ndim(); i++, j++) { + if (x.shape(i) == 1) { + if (shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + strides.push_back(std::move(xstrides)); + args.push_back(strides.back().data()); + } + + // Get the kernel name from the lib + int ndim = shape.size(); + bool dynamic = ndim >= 8; + auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + if (!contiguous) { + kernel_name += std::to_string(shape.size()); + } + + // Get the function + auto fn_ptr = compile(kernel_name); + + // If it doesn't exist, compile it + if (fn_ptr == nullptr) { + std::ostringstream kernel; + kernel << preamble << std::endl; + kernel << "extern \"C\" {" << std::endl; + build_kernel( + kernel, + kernel_name, + inputs_, + outputs_, + tape_, + constant_ids_, + contiguous, + ndim); + // Close extern "C" + kernel << "}" << std::endl; + + // Compile and get function pointer + fn_ptr = compile(kernel_name, kernel.str()); + } + + // Allocate space for the outputs possibly with input donation + if (contiguous) { + int o = 0; + std::vector strides; + size_t data_size; + array::Flags flags; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().contiguous && in.size() > 1 && in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].copy_shared_buffer(in); + } + // Get representative input flags to properly set non-donated outputs + if (strides.empty() && in.size() == outputs[0].size()) { + strides = in.strides(); + flags = in.flags(); + data_size = in.data_size(); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data( + allocator::malloc_or_wait(data_size * outputs[o].itemsize()), + data_size, + strides, + flags); + } + } else { + int o = 0; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Row contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].copy_shared_buffer(in); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + } + } + + for (auto& x : outputs) { + args.push_back(x.data()); + } + if (!contiguous) { + args.push_back((void*)outputs[0].shape().data()); + } else { + args.push_back((void*)outputs[0].data_size()); + } + auto fun = (void (*)(void**))fn_ptr; + fun(args.data()); } } // namespace mlx::core diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h new file mode 100644 index 000000000..adbd5399c --- /dev/null +++ b/mlx/backend/common/compiled.h @@ -0,0 +1,52 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +inline bool is_static_cast(const Primitive& p) { + return ( + typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || + typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); +} + +std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids); + +std::string get_type_string(Dtype d); + +template +void print_float_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + os << std::setprecision(std::numeric_limits::digits10 + 1) + << x.item() << std::setprecision(old_precision); +} + +template +void print_int_constant(std::ostream& os, const array& x) { + os << x.item(); +} + +template +void print_complex_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + T constant = x.item(); + + os << get_type_string(x.dtype()) << "(" + << std::setprecision(std::numeric_limits::digits10 + 1) + << constant.real() << ", " << constant.imag() << ")" + << std::setprecision(old_precision); +} + +void print_constant(std::ostream& os, const array& x); + +} // namespace mlx::core diff --git a/mlx/backend/common/compiled_preamble.h b/mlx/backend/common/compiled_preamble.h new file mode 100644 index 000000000..8ccaa8bd7 --- /dev/null +++ b/mlx/backend/common/compiled_preamble.h @@ -0,0 +1,1121 @@ +// Copyright © 2023-2024 Apple Inc. + +const std::string preamble = R"( +#include +#include +#include + +#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +#include +typedef __fp16 float16_t; + +#else + +#define ADD_HALF_BINOPS +#include +#include +#include +#include + +#define __MLX_HALF_NAN__ 0x7D00 + + +namespace { +union float_bits_fp16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_Float16 { + uint16_t bits_; + + // Default constructor + _MLX_Float16() = default; + + // Default copy constructor + _MLX_Float16(_MLX_Float16 const&) = default; + + // Appease std::vector for being special + _MLX_Float16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_Float16& operator=(const float& x) { + return (*this = _MLX_Float16(x)); + } + + // From float32 + _MLX_Float16(const float& x) : bits_(0) { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 in; + + // Take fp32 bits + in.f = x; + + // Find and take sign bit + uint32_t x_sign_32 = in.u & uint32_t(0x80000000); + uint16_t x_sign_16 = (x_sign_32 >> 16); + + if (std::isnan(x)) { + bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__); + } else { + // Union + float_bits_fp16 inf_scale, zero_scale, magic_bits; + + // Find exponent bits and take the max supported by half + uint32_t x_expo_32 = in.u & uint32_t(0x7f800000); + uint32_t max_expo_32 = uint32_t(0x38800000); + x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32; + x_expo_32 += uint32_t(15) << 23; + + // Handle scaling to inf as needed + inf_scale.u = uint32_t(0x77800000); + zero_scale.u = uint32_t(0x08800000); + + // Combine with magic and let addition do rounding + magic_bits.u = x_expo_32; + magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f; + + // Take the lower 5 bits of the exponent + uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00)); + + // Collect the lower 12 bits which have the mantissa + uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff); + + // Combine sign, exp and mantissa + bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16)); + } + } + + // To float32 + operator float() const { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 out; + + uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000); + uint32_t base = (bits_ << 16); + uint32_t two_base = base + base; + + uint32_t denorm_max = 1u << 27; + if (two_base < denorm_max) { + out.u = uint32_t(126) << 23; // magic mask + out.u |= (two_base >> 17); // Bits from fp16 + out.f -= 0.5f; // magic bias + } else { + out.u = uint32_t(0xE0) << 23; // exponent offset + out.u += (two_base >> 4); // Bits from fp16 + float out_unscaled = out.f; // Store value + out.u = uint32_t(0x7800000); // exponent scale + out.f *= out_unscaled; + } + + // Add sign + out.u |= x_sign_32; + + return out.f; + } +}; + +#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define half_binop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, float, float, float); \ + half_binop_helper(__op__, __operator__, double, double, double); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float); + +half_binop(+, operator+); +half_binop(-, operator-); +half_binop(*, operator*); +half_binop(/, operator/); + +#undef half_binop + +// Comparison ops +#define half_compop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, bool, float, float); \ + half_binop_helper(__op__, __operator__, bool, double, double); \ + half_binop_helper(__op__, __operator__, bool, int32_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + half_binop_helper(__op__, __operator__, bool, int64_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint64_t, float); + +half_compop(>, operator>); +half_compop(<, operator<); +half_compop(>=, operator>=); +half_compop(<=, operator<=); +half_compop(==, operator==); +half_compop(!=, operator!=); + +#undef half_compop + +// Negative +inline _MLX_Float16 operator-(_MLX_Float16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define half_inplace_op(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +half_inplace_op(+, operator+=); +half_inplace_op(-, operator-=); +half_inplace_op(*, operator*=); +half_inplace_op(/, operator/=); + +#undef half_inplace_op + +// Bitwise ops + +#define half_bitop(__op__, __operator__) \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +half_bitop(|, operator|); +half_bitop(&, operator&); +half_bitop(^, operator^); + +#undef half_bitop + +#define half_inplace_bitop(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +half_inplace_bitop(|, operator|=); +half_inplace_bitop(&, operator&=); +half_inplace_bitop(^, operator^=); + +#undef half_inplace_bitop + +typedef struct _MLX_Float16 float16_t; + +#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#ifdef __ARM_FEATURE_BF16 + +#include +typedef __bf16 bfloat16_t; + +#else + +#define ADD_HALF_BINOPS +#include +#include +#include +#include + +#define __MLX_BFLOAT_NAN__ 0x7FC0 + + +namespace { +union float_bits_bf16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_BFloat16 { + uint16_t bits_; + + // Default constructor + _MLX_BFloat16() = default; + + // Default copy constructor + _MLX_BFloat16(_MLX_BFloat16 const&) = default; + + // Appease std::vector for being special + _MLX_BFloat16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_BFloat16& operator=(const float& x) { + return (*this = _MLX_BFloat16(x)); + } + + // From float32 + _MLX_BFloat16(const float& x) { + if (std::isnan(x)) { + bits_ = __MLX_BFLOAT_NAN__; + } else { + // Union + float_bits_bf16 in; + + // Take bits + in.f = x; + + // Round to nearest even + in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); + + // Take upper 16 bits + bits_ = in.u >> 16; + } + } + + // To float32 + operator float() const { + // Union + float_bits_bf16 out; + + // Upper 16 bits are the data and lower 16 bits are 0s + out.u = ((uint32_t)bits_) << 16; + + return out.f; + } +}; + +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, double, double, double); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +#undef bfloat_binop + +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, double, double); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop + +// Negative +inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define bfloat_inplace_op(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_op(+, operator+=); +bfloat_inplace_op(-, operator-=); +bfloat_inplace_op(*, operator*=); +bfloat_inplace_op(/, operator/=); + +#undef bfloat_inplace_op + +// Bitwise ops + +#define bfloat_bitop(__op__, __operator__) \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +bfloat_bitop(|, operator|); +bfloat_bitop(&, operator&); +bfloat_bitop(^, operator^); + +#undef bfloat_bitop + +#define bfloat_inplace_bitop(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_bitop(|, operator|=); +bfloat_inplace_bitop(&, operator&=); +bfloat_inplace_bitop(^, operator^=); + +#undef bfloat_inplace_bitop + +typedef struct _MLX_BFloat16 bfloat16_t; + +#endif // __ARM_FEATURE_BF16 + +#ifdef ADD_HALF_BINOPS + +// clang-format off +#define fp16_bf16_binop_helper(__op__, __operator__) \ + inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp16_bf16_binop_helper(+, operator+) +fp16_bf16_binop_helper(-, operator-) +fp16_bf16_binop_helper(*, operator*) +fp16_bf16_binop_helper(/, operator/) +// clang-format on + +#endif + + +struct complex64_t; + +template +static constexpr bool can_convert_to_complex64 = + !std::is_same_v && std::is_convertible_v; + +struct complex64_t : public std::complex { + complex64_t(float v, float u) : std::complex(v, u){}; + complex64_t(std::complex v) : std::complex(v){}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex64_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +inline bool operator>=(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || + (a.real() == b.real() && a.imag() >= b.imag()); +} + +inline bool operator>(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); +} + +inline complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && (real < 0 != b.real() < 0)) + real += b.real(); + if (imag != 0 && (imag < 0 != b.imag() < 0)) + imag += b.imag(); + return {real, imag}; +} + +inline bool operator<=(const complex64_t& a, const complex64_t& b) { + return operator>=(b, a); +} + +inline bool operator<(const complex64_t& a, const complex64_t& b) { + return operator>(b, a); +} + +inline complex64_t operator-(const complex64_t& v) { + return -static_cast>(v); +} + +// clang-format off +#define complex_binop_helper(_op_, _operator_, itype) \ + inline complex64_t _operator_(itype x, const complex64_t& y) { \ + return static_cast(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, itype y) { \ + return x _op_ static_cast(y); \ + } + +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ + return static_cast>(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ + complex_binop_helper(_op_, _operator_, float) +// clang-format on + +complex_binop(+, operator+) + +typedef union { + int i; + float f; +} IntOrFloat; + +inline float fast_exp(float x) { + x *= 1.442695; // multiply with log_2(e) + float ipart, fpart; + IntOrFloat epart; + x = std::max(-80.f, std::min(x, 80.f)); + ipart = std::floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = x * fpart + 1.339887440266574e-3f; + x = x * fpart + 9.618437357674640e-3f; + x = x * fpart + 5.550332471162809e-2f; + x = x * fpart + 2.402264791363012e-1f; + x = x * fpart + 6.931472028550421e-1f; + x = x * fpart + 1.000000000000000f; + + // generate 2**ipart in the floating point representation using integer + // bitshifting + epart.i = (int(ipart) + 127) << 23; + + return epart.f * x; +} + +float fast_erf(float a) { + float r, s, t, u; + t = std::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = std::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = std::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = std::fma(r, s, u); + r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = std::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - std::exp(r); + r = std::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = std::fma(r, a, a); + } + return r; +} + +float fast_erfinv(float a) { + auto t = std::fma(a, 0.0f - a, 1.0f); + t = std::log(t); + float p; + if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} + +struct Abs { + template + T operator()(T x) { + return std::abs(x); + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return std::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return std::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return std::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return std::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return std::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return std::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return std::ceil(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return std::cos(x); + }; +}; + +struct Cosh { + template + T operator()(T x) { + return std::cosh(x); + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(fast_erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(fast_erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return fast_exp(x); + }; +}; + +struct Floor { + template + T operator()(T x) { + return std::floor(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return std::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return std::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return std::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return std::rint(x); + } + + std::complex operator()(std::complex x) { + return {std::rint(x.real()), std::rint(x.imag())}; + } +}; + +struct Sigmoid { + template + T operator()(T x) { + auto one = static_cast(1.0); + return one / (one + fast_exp(-x)); + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + } + uint8_t operator()(uint8_t x) { + return x != 0; + } + uint16_t operator()(uint16_t x) { + return x != 0; + } + uint32_t operator()(uint32_t x) { + return x != 0; + } + uint64_t operator()(uint64_t x) { + return x != 0; + } +}; + +struct Sin { + template + T operator()(T x) { + return std::sin(x); + }; +}; + +struct Sinh { + template + T operator()(T x) { + return std::sinh(x); + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return std::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return static_cast(1.0) / std::sqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return std::tan(x); + }; +}; + +struct Tanh { + template + T operator()(T x) { + return std::tanh(x); + }; +}; + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + std::enable_if_t & !std::is_signed_v, T> operator()( + T numerator, + T denominator) { + return numerator % denominator; + } + + template + std::enable_if_t & std::is_signed_v, T> operator()( + T numerator, + T denominator) { + auto r = numerator % denominator; + if (r != 0 && (r < 0 != denominator < 0)) + r += denominator; + return r; + } + + template + std::enable_if_t, T> operator()( + T numerator, + T denominator) { + auto r = std::fmod(numerator, denominator); + if (r != 0 && (r < 0 != denominator < 0)) { + r += denominator; + } + return r; + } + + std::complex operator()( + std::complex a, std::complex b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && ((real < 0) != (b.real() < 0))) + real += b.real(); + if (imag != 0 && ((imag < 0) != (b.imag() < 0))) + imag += b.imag(); + return {real, imag}; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (std::isnan(x) && std::isnan(y)); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + constexpr float inf = std::numeric_limits::infinity(); + auto maxval = (x > y) ? x : y; + auto minval = (x > y) ? y : x; + return (minval == -inf || maxval == inf) + ? maxval + : static_cast( + maxval + std::log1p(fast_exp(minval - maxval))); + }; +}; + +struct Maximum { + template + std::enable_if_t, T> operator()(T x, T y) { + return (x > y) ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return (x > y) ? x : y; + } +}; + +struct Minimum { + template + std::enable_if_t, T> operator()(T x, T y) { + return x < y ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return x < y ? x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } +}; + +struct Power { + template + std::enable_if_t, T> operator()(T base, T exp) { + return std::pow(base, exp); + } + + template + std::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; +)"; diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 6befc8eb9..c65028d95 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -43,7 +43,6 @@ DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT_MULTI(DivMod) DEFAULT(Ceil) -DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Convolution) DEFAULT(Copy) diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 1f27a2493..681d635ba 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/compiled_preamble.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" @@ -11,125 +12,6 @@ namespace mlx::core { -inline bool is_static_cast(const Primitive& p) { - return ( - typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || - typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); -} - -inline auto get_type_string(Dtype d) { - switch (d) { - case float32: - return "float"; - case float16: - return "half"; - case bfloat16: - return "bfloat16_t"; - case bool_: - return "bool"; - case int8: - return "int8_t"; - case int16: - return "int16_t"; - case int32: - return "int32_t"; - case int64: - return "int64_t"; - case uint8: - return "uint8_t"; - case uint16: - return "uint16_t"; - case uint32: - return "uint32_t"; - case uint64: - return "uint64_t"; - default: { - std::ostringstream msg; - msg << "Unsupported compilation type " << d; - throw std::runtime_error(msg.str()); - } - } -} - -template -void print_float_constant(std::ostream& os, const array& x) { - auto old_precision = os.precision(); - os << std::setprecision(std::numeric_limits::digits10 + 1) - << x.item() << std::setprecision(old_precision); -} - -template -void print_int_constant(std::ostream& os, const array& x) { - os << x.item(); -} - -void print_constant(std::ostream& os, const array& x) { - switch (x.dtype()) { - case float32: - return print_float_constant(os, x); - case float16: - return print_float_constant(os, x); - case bfloat16: - return print_float_constant(os, x); - case int8: - return print_int_constant(os, x); - case int16: - return print_int_constant(os, x); - case int32: - return print_int_constant(os, x); - case int64: - return print_int_constant(os, x); - case uint8: - return print_int_constant(os, x); - case uint16: - return print_int_constant(os, x); - case uint32: - return print_int_constant(os, x); - case uint64: - return print_int_constant(os, x); - case bool_: - os << std::boolalpha << x.item(); - return; - default: - throw std::runtime_error("Unsupported constant type"); - } -} - -inline std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids) { - std::ostringstream os; - std::ostringstream constant_hasher; - - // The primitives describing the tape. For unary and binary primitives this - // must be enough to describe the full computation. - for (auto& a : tape) { - a.primitive().print(os); - } - os << "_"; - - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - os << "C"; - print_constant(constant_hasher, x); - } else { - os << ((x.size() == 1) ? "S" : "V"); - } - } - os << "_"; - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - continue; - } - os << kindof(x.dtype()) << x.itemsize(); - } - os << "_" << std::hash{}(constant_hasher.str()); - - return os.str(); -} - inline void build_kernel( std::ostream& os, const std::string& kernel_name, @@ -286,7 +168,7 @@ inline void build_kernel( if (cnt > 31) { std::ostringstream msg; - msg << "[compile] Too many inputs/outputs fused in the Metal Compile " + msg << "[compile] Too many inputs/outputs fused in the Metal Compiled " << "primitive which exhausted the available argument buffers for " << "the kernel. Please file an issue with the function that results " << "in this error. The name of the kernel is '" << kernel_name << "'"; @@ -348,11 +230,6 @@ void Compiled::eval_gpu( lib = d.get_library(kernel_lib_, kernel_source_); } - // Allocate space for the outputs - for (auto& out : outputs) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - } - // Figure out which kernel we are using auto& output_shape = outputs[0].shape(); bool contiguous = true; @@ -443,6 +320,27 @@ void Compiled::eval_gpu( } } + // Allocate space for the outputs possibly with input donation + { + int o = 0; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Row contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].move_shared_buffer(in); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + } + } + // Put the outputs in for (auto& x : outputs) { set_array_buffer(compute_encoder, x, cnt++); diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h index 82a9e9c5c..d5bf33696 100644 --- a/mlx/backend/metal/kernels/compiled_preamble.h +++ b/mlx/backend/metal/kernels/compiled_preamble.h @@ -2,3 +2,5 @@ #include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/unary.h" + +typedef half float16_t; diff --git a/mlx/compile.cpp b/mlx/compile.cpp index e69c442f2..a648d191f 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -319,6 +319,9 @@ void compile_simplify( case 1: v = *a.data(); break; + case 2: + v = *a.data(); + break; case 4: v = *a.data(); break; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index acb5d0046..2b854960b 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -3,8 +3,8 @@ namespace mlx::core::fast { // Custom primitive accepts a fallback function which it uses for -// transformations. Transformations are virtual so that derived classes may to -// override the default behavior +// transformations. Transformations are virtual so that derived classes may +// override the default behavior. class Custom : public Primitive { public: explicit Custom( diff --git a/mlx/primitives.h b/mlx/primitives.h index b06a35780..9d0a9181c 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -496,8 +496,6 @@ class Compiled : public Primitive { std::string kernel_lib_; std::string kernel_source_; - - void eval(const std::vector& inputs, std::vector& out); }; class Concatenate : public UnaryPrimitive { diff --git a/mlx/types/complex.h b/mlx/types/complex.h index 19ab1b542..46f4310f9 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -60,25 +60,30 @@ inline complex64_t operator-(const complex64_t& v) { // clang-format off #define complex_binop_helper(_op_, _operator_, itype) \ inline complex64_t _operator_(itype x, const complex64_t& y) { \ - return x _op_ static_cast>(y); \ + return static_cast(x) _op_ y; \ } \ inline complex64_t _operator_(const complex64_t& x, itype y) { \ - return static_cast>(x) _op_ y; \ + return x _op_ static_cast(y); \ } -#define complex_binop(_op_, _operator_) \ - inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ - return static_cast>(x) \ - _op_ static_cast>(y); \ - } \ - complex_binop_helper(_op_, _operator_, bool) \ - complex_binop_helper(_op_, _operator_, uint32_t) \ - complex_binop_helper(_op_, _operator_, uint64_t) \ - complex_binop_helper(_op_, _operator_, int32_t) \ - complex_binop_helper(_op_, _operator_, int64_t) \ - complex_binop_helper(_op_, _operator_, float16_t) \ - complex_binop_helper(_op_, _operator_, bfloat16_t) \ - complex_binop_helper(_op_, _operator_, const std::complex&) \ +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ + return static_cast>(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ complex_binop_helper(_op_, _operator_, float) // clang-format on diff --git a/mlx/utils.h b/mlx/utils.h index 88f47e3e1..ebcca3a1e 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -77,7 +77,7 @@ std::ostream& operator<<(std::ostream& os, array a); std::ostream& operator<<(std::ostream& os, const std::vector& v); std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { - return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j"; + return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; } inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { return os << static_cast(v); diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index be460e3b6..8ad67a1ed 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -44,8 +44,8 @@ TEST_CASE("test compile with grad") { auto y = array(1.0f); auto grads_expected = grad_fun({x, y}); auto grads_compile = compile(grad_fun)({x, y}); - CHECK_EQ(grads_compile[0].item(), grads_expected[0].item()); - CHECK_EQ(grads_compile[1].item(), grads_expected[1].item()); + CHECK(allclose(grads_compile[0], grads_expected[0]).item()); + CHECK(allclose(grads_compile[1], grads_expected[1]).item()); } TEST_CASE("test compile inputs with primitive") { @@ -272,7 +272,7 @@ TEST_CASE("test compile unary fused") { CHECK_EQ(out.inputs()[0].id(), x.id()); auto expected_out = unary_fused_1({array(2.0)})[0]; - CHECK_EQ(out.item(), expected_out.item()); + CHECK(allclose(out, expected_out).item()); } { From 0925af43b041f01c3309f1b3c52bf040898b3189 Mon Sep 17 00:00:00 2001 From: Jack Mousseau Date: Sun, 18 Feb 2024 12:50:10 -0800 Subject: [PATCH 37/42] Remove unused variables (#706) --- mlx/backend/accelerate/quantized.cpp | 2 -- mlx/backend/common/compiled.cpp | 1 - mlx/backend/metal/conv.cpp | 1 - mlx/backend/metal/indexing.cpp | 4 ---- mlx/backend/metal/primitives.cpp | 1 - mlx/io/gguf_quants.cpp | 1 - mlx/primitives.cpp | 1 - 7 files changed, 11 deletions(-) diff --git a/mlx/backend/accelerate/quantized.cpp b/mlx/backend/accelerate/quantized.cpp index cc31c88a3..e9fec1303 100644 --- a/mlx/backend/accelerate/quantized.cpp +++ b/mlx/backend/accelerate/quantized.cpp @@ -24,8 +24,6 @@ void _qmm_t_4_64( constexpr int bitmask = (1 << bits) - 1; constexpr int pack_factor = 32 / bits; constexpr int packs_in_group = group_size / pack_factor; - const int Kg = K / group_size; - const int Kw = K / pack_factor; for (int m = 0; m < M; m++) { const uint32_t* w_local = w; diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 8bf1f43e4..914b85ae3 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -410,7 +410,6 @@ void Compiled::eval_cpu( // Get the kernel name from the lib int ndim = shape.size(); - bool dynamic = ndim >= 8; auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); if (!contiguous) { kernel_name += std::to_string(shape.size()); diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 9fbbffb33..4ade8da17 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -182,7 +182,6 @@ void implicit_gemm_conv_2D_gpu( int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; int implicit_N = conv_params.O; - int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C; size_t grid_dim_x = (implicit_N + bn - 1) / bn; size_t grid_dim_y = (implicit_M + bm - 1) / bm; diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 6908f8684..4eeb8858e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -167,10 +167,6 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { auto& upd = inputs.back(); size_t nthreads = upd.size(); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } compute_encoder->setComputePipelineState(kernel); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 757888a66..38ec5993c 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -691,7 +691,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { // organize into grid nkeys x elem_per_key MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); auto compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index 636648bc7..b9fe1e3bf 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -114,7 +114,6 @@ void gguf_load_quantized( << "has incompatible last dim shape: " << shape[shape.size() - 1]; throw std::runtime_error(msg.str()); } - const uint64_t num_blocks = tensor.num_weights / weights_per_block; std::vector weights_shape = shape; weights_shape.back() /= (weights_per_byte * 4); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index a7e1d205d..b78f8d405 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -628,7 +628,6 @@ std::vector Convolution::vjp( auto& wt = primals[1]; auto cotan = cotangents[0]; - int N = in.shape(0); int O = wt.shape(0); // Resolve Padded input shapes and strides From 1a4f4c5ea66d6be1f5568bba03170c1cd71f78d6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 Feb 2024 06:12:53 -0800 Subject: [PATCH 38/42] Refactor CPU compile preamble (#708) * refactor cpu preamble * fix include order * fix some issues' * fixes for linux * try to fix includes * add back warning suppression * more linux fixes --- mlx/backend/accelerate/primitives.cpp | 5 +- mlx/backend/common/CMakeLists.txt | 31 + mlx/backend/common/binary.cpp | 131 +- mlx/backend/common/compiled.cpp | 10 +- mlx/backend/common/compiled_preamble.h | 1122 +----------------- mlx/backend/common/erf.h | 11 - mlx/backend/common/make_compiled_preamble.sh | 34 + mlx/backend/common/ops.h | 591 +++++++++ mlx/backend/common/primitives.cpp | 89 +- mlx/backend/common/unary.h | 53 - mlx/types/complex.h | 4 +- tests/ops_tests.cpp | 6 +- 12 files changed, 732 insertions(+), 1355 deletions(-) delete mode 100644 mlx/backend/common/erf.h create mode 100644 mlx/backend/common/make_compiled_preamble.sh create mode 100644 mlx/backend/common/ops.h diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 4cccd35ae..e147b5888 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -81,11 +81,8 @@ void Abs::eval_cpu(const std::vector& inputs, array& out) { } else if (in.dtype() == int32 && in.flags().contiguous) { set_unary_output_data(in, out); vDSP_vabsi(in.data(), 1, out.data(), 1, in.data_size()); - } else if (is_unsigned(in.dtype())) { - // No-op for unsigned types - out.copy_shared_buffer(in); } else { - unary(in, out, AbsOp()); + eval(inputs, out); } } diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index b25001f2c..38a9819e5 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,3 +1,33 @@ + +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(CLANG TRUE) +endif() + +add_custom_command( + OUTPUT compiled_preamble.cpp + COMMAND /bin/bash + ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_CXX_COMPILER} + ${CMAKE_SOURCE_DIR} + ${CLANG} + + DEPENDS make_compiled_preamble.sh + compiled_preamble.h + ${CMAKE_SOURCE_DIR}/mlx/types/half_types.h + ${CMAKE_SOURCE_DIR}/mlx/types/fp16.h + ${CMAKE_SOURCE_DIR}/mlx/types/bf16.h + ${CMAKE_SOURCE_DIR}/mlx/types/complex.h + ops.h +) + +add_custom_target( + cpu_compiled_preamble + DEPENDS compiled_preamble.cpp +) + +add_dependencies(mlx cpu_compiled_preamble) + target_sources( mlx PRIVATE @@ -19,4 +49,5 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index 855e8467b..ec7097797 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -7,6 +7,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary_two.h" +#include "mlx/backend/common/ops.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -73,7 +74,7 @@ void Add::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x + y; }); + binary(a, b, out, detail::Add()); } void DivMod::eval( @@ -135,106 +136,56 @@ void Divide::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x / y; }); + binary(a, b, out, detail::Divide()); } -struct RemainderFn { - template - std::enable_if_t & !std::is_signed_v, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } - - template - std::enable_if_t & std::is_signed_v, T> operator()( - T numerator, - T denominator) { - auto r = numerator % denominator; - if (r != 0 && (r < 0 != denominator < 0)) - r += denominator; - return r; - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - auto r = std::fmod(numerator, denominator); - if (r != 0 && (r < 0 != denominator < 0)) { - r += denominator; - } - return r; - } - - complex64_t operator()(complex64_t numerator, complex64_t denominator) { - return numerator % denominator; - } -}; - void Remainder::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, RemainderFn{}); + binary(a, b, out, detail::Remainder()); } void Equal::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (equal_nan_) { - comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) { - return x == y || (std::isnan(x) && std::isnan(y)); - }); + comparison_op(inputs[0], inputs[1], out, detail::NaNEqual()); } else { - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; }); + comparison_op(inputs[0], inputs[1], out, detail::Equal()); } } void Greater::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; }); + comparison_op(inputs[0], inputs[1], out, detail::Greater()); } void GreaterEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; }); + comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual()); } void Less::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; }); + comparison_op(inputs[0], inputs[1], out, detail::Less()); } void LessEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; }); + comparison_op(inputs[0], inputs[1], out, detail::LessEqual()); } void LogAddExp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - auto op = [](auto x, auto y) { - constexpr float inf = std::numeric_limits::infinity(); - auto maxval = (x > y) ? x : y; - auto minval = (x > y) ? y : x; - return (minval == -inf || maxval == inf) - ? maxval - : static_cast( - maxval + std::log1p(std::exp(minval - maxval))); - }; if (is_floating_point(out.dtype())) { if (out.dtype() == float32) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == float16) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == bfloat16) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else { std::ostringstream err; err << "[logaddexp] Does not support " << out.dtype(); @@ -251,84 +202,40 @@ void Maximum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - - if (is_floating_point(out.dtype())) { - binary(a, b, out, [](auto x, auto y) { - if (std::isnan(x)) { - return x; - } - return (x > y) ? x : y; - }); - } else { - binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; }); - } + binary(a, b, out, detail::Maximum()); } void Minimum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - if (is_floating_point(out.dtype())) { - binary(a, b, out, [](auto x, auto y) { - if (std::isnan(x)) { - return x; - } - return (x < y) ? x : y; - }); - } else { - binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; }); - } + binary(a, b, out, detail::Minimum()); } void Multiply::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x * y; }); + binary(a, b, out, detail::Multiply()); } void NotEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; }); + comparison_op(inputs[0], inputs[1], out, detail::NotEqual()); } -struct PowerFn { - template - std::enable_if_t, T> operator()(T base, T exp) { - return std::pow(base, exp); - } - - template - std::enable_if_t, T> operator()(T base, T exp) { - if (exp < 0) { - throw std::invalid_argument( - "Integers cannot be raise to negative powers"); - } - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } -}; - void Power::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, PowerFn{}); + binary(a, b, out, detail::Power()); } void Subtract::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x - y; }); + binary(a, b, out, detail::Subtract()); } } // namespace mlx::core diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 914b85ae3..52bcac4fa 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -178,7 +178,13 @@ void* compile( build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " << source_file_path << " -o " << shared_lib_path; std::string build_command_str = build_command.str(); - system(build_command_str.c_str()); + auto return_code = system(build_command_str.c_str()); + if (return_code) { + std::ostringstream msg; + msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name + << " with error code " << return_code << "." << std::endl; + throw std::runtime_error(msg.str()); + } } // load library @@ -421,7 +427,7 @@ void Compiled::eval_cpu( // If it doesn't exist, compile it if (fn_ptr == nullptr) { std::ostringstream kernel; - kernel << preamble << std::endl; + kernel << get_kernel_preamble() << std::endl; kernel << "extern \"C\" {" << std::endl; build_kernel( kernel, diff --git a/mlx/backend/common/compiled_preamble.h b/mlx/backend/common/compiled_preamble.h index 8ccaa8bd7..84b77d29d 100644 --- a/mlx/backend/common/compiled_preamble.h +++ b/mlx/backend/common/compiled_preamble.h @@ -1,1121 +1,11 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2023-24 Apple Inc. -const std::string preamble = R"( -#include -#include -#include - -#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC - -#include -typedef __fp16 float16_t; - -#else - -#define ADD_HALF_BINOPS -#include -#include -#include -#include - -#define __MLX_HALF_NAN__ 0x7D00 - - -namespace { -union float_bits_fp16 { - float f; - uint32_t u; -}; -} // namespace - -struct _MLX_Float16 { - uint16_t bits_; - - // Default constructor - _MLX_Float16() = default; - - // Default copy constructor - _MLX_Float16(_MLX_Float16 const&) = default; - - // Appease std::vector for being special - _MLX_Float16& operator=(std::vector::reference x) { - bits_ = x; - return *this; - } - - _MLX_Float16& operator=(const float& x) { - return (*this = _MLX_Float16(x)); - } - - // From float32 - _MLX_Float16(const float& x) : bits_(0) { - // Conversion following - // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h - - // Union - float_bits_fp16 in; - - // Take fp32 bits - in.f = x; - - // Find and take sign bit - uint32_t x_sign_32 = in.u & uint32_t(0x80000000); - uint16_t x_sign_16 = (x_sign_32 >> 16); - - if (std::isnan(x)) { - bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__); - } else { - // Union - float_bits_fp16 inf_scale, zero_scale, magic_bits; - - // Find exponent bits and take the max supported by half - uint32_t x_expo_32 = in.u & uint32_t(0x7f800000); - uint32_t max_expo_32 = uint32_t(0x38800000); - x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32; - x_expo_32 += uint32_t(15) << 23; - - // Handle scaling to inf as needed - inf_scale.u = uint32_t(0x77800000); - zero_scale.u = uint32_t(0x08800000); - - // Combine with magic and let addition do rounding - magic_bits.u = x_expo_32; - magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f; - - // Take the lower 5 bits of the exponent - uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00)); - - // Collect the lower 12 bits which have the mantissa - uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff); - - // Combine sign, exp and mantissa - bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16)); - } - } - - // To float32 - operator float() const { - // Conversion following - // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h - - // Union - float_bits_fp16 out; - - uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000); - uint32_t base = (bits_ << 16); - uint32_t two_base = base + base; - - uint32_t denorm_max = 1u << 27; - if (two_base < denorm_max) { - out.u = uint32_t(126) << 23; // magic mask - out.u |= (two_base >> 17); // Bits from fp16 - out.f -= 0.5f; // magic bias - } else { - out.u = uint32_t(0xE0) << 23; // exponent offset - out.u += (two_base >> 4); // Bits from fp16 - float out_unscaled = out.f; // Store value - out.u = uint32_t(0x7800000); // exponent scale - out.f *= out_unscaled; - } - - // Add sign - out.u |= x_sign_32; - - return out.f; - } -}; - -#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - inline otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \ - inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -// Operators -#define half_binop(__op__, __operator__) \ - half_binop_base( \ - __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \ - half_binop_helper(__op__, __operator__, float, float, float); \ - half_binop_helper(__op__, __operator__, double, double, double); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float); - -half_binop(+, operator+); -half_binop(-, operator-); -half_binop(*, operator*); -half_binop(/, operator/); - -#undef half_binop - -// Comparison ops -#define half_compop(__op__, __operator__) \ - half_binop_base( \ - __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \ - half_binop_helper(__op__, __operator__, bool, float, float); \ - half_binop_helper(__op__, __operator__, bool, double, double); \ - half_binop_helper(__op__, __operator__, bool, int32_t, float); \ - half_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - half_binop_helper(__op__, __operator__, bool, int64_t, float); \ - half_binop_helper(__op__, __operator__, bool, uint64_t, float); - -half_compop(>, operator>); -half_compop(<, operator<); -half_compop(>=, operator>=); -half_compop(<=, operator<=); -half_compop(==, operator==); -half_compop(!=, operator!=); - -#undef half_compop - -// Negative -inline _MLX_Float16 operator-(_MLX_Float16 lhs) { - return -static_cast(lhs); -} - -// Inplace ops -#define half_inplace_op(__op__, __operator__) \ - inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } \ - inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } - -half_inplace_op(+, operator+=); -half_inplace_op(-, operator-=); -half_inplace_op(*, operator*=); -half_inplace_op(/, operator/=); - -#undef half_inplace_op - -// Bitwise ops - -#define half_bitop(__op__, __operator__) \ - inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \ - _MLX_Float16 out; \ - out.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return out; \ - } \ - inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \ - _MLX_Float16 out; \ - out.bits_ = lhs.bits_ __op__ rhs; \ - return out; \ - } \ - inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \ - _MLX_Float16 out; \ - out.bits_ = lhs __op__ rhs.bits_; \ - return out; \ - } - -half_bitop(|, operator|); -half_bitop(&, operator&); -half_bitop(^, operator^); - -#undef half_bitop - -#define half_inplace_bitop(__op__, __operator__) \ - inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return lhs; \ - } \ - inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs; \ - return lhs; \ - } - -half_inplace_bitop(|, operator|=); -half_inplace_bitop(&, operator&=); -half_inplace_bitop(^, operator^=); - -#undef half_inplace_bitop - -typedef struct _MLX_Float16 float16_t; - -#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC -#ifdef __ARM_FEATURE_BF16 - -#include -typedef __bf16 bfloat16_t; - -#else - -#define ADD_HALF_BINOPS -#include -#include -#include -#include - -#define __MLX_BFLOAT_NAN__ 0x7FC0 - - -namespace { -union float_bits_bf16 { - float f; - uint32_t u; -}; -} // namespace - -struct _MLX_BFloat16 { - uint16_t bits_; - - // Default constructor - _MLX_BFloat16() = default; - - // Default copy constructor - _MLX_BFloat16(_MLX_BFloat16 const&) = default; - - // Appease std::vector for being special - _MLX_BFloat16& operator=(std::vector::reference x) { - bits_ = x; - return *this; - } - - _MLX_BFloat16& operator=(const float& x) { - return (*this = _MLX_BFloat16(x)); - } - - // From float32 - _MLX_BFloat16(const float& x) { - if (std::isnan(x)) { - bits_ = __MLX_BFLOAT_NAN__; - } else { - // Union - float_bits_bf16 in; - - // Take bits - in.f = x; - - // Round to nearest even - in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); - - // Take upper 16 bits - bits_ = in.u >> 16; - } - } - - // To float32 - operator float() const { - // Union - float_bits_bf16 out; - - // Upper 16 bits are the data and lower 16 bits are 0s - out.u = ((uint32_t)bits_) << 16; - - return out.f; - } -}; - -#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - inline otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ - inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -// Operators -#define bfloat_binop(_op_, _operator_) \ - bfloat_binop_base( \ - _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(_op_, _operator_, float, float, float); \ - bfloat_binop_helper(_op_, _operator_, double, double, double); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); - -bfloat_binop(+, operator+); -bfloat_binop(-, operator-); -bfloat_binop(*, operator*); -bfloat_binop(/, operator/); - -#undef bfloat_binop - -// Comparison ops -#define bfloat_compop(__op__, __operator__) \ - bfloat_binop_base( \ - __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(__op__, __operator__, bool, float, float); \ - bfloat_binop_helper(__op__, __operator__, bool, double, double); \ - bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); - -bfloat_compop(>, operator>); -bfloat_compop(<, operator<); -bfloat_compop(>=, operator>=); -bfloat_compop(<=, operator<=); -bfloat_compop(==, operator==); -bfloat_compop(!=, operator!=); - -#undef bfloat_compop - -// Negative -inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { - return -static_cast(lhs); -} - -// Inplace ops -#define bfloat_inplace_op(__op__, __operator__) \ - inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } \ - inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } - -bfloat_inplace_op(+, operator+=); -bfloat_inplace_op(-, operator-=); -bfloat_inplace_op(*, operator*=); -bfloat_inplace_op(/, operator/=); - -#undef bfloat_inplace_op - -// Bitwise ops - -#define bfloat_bitop(__op__, __operator__) \ - inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ - _MLX_BFloat16 out; \ - out.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return out; \ - } \ - inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ - _MLX_BFloat16 out; \ - out.bits_ = lhs.bits_ __op__ rhs; \ - return out; \ - } \ - inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ - _MLX_BFloat16 out; \ - out.bits_ = lhs __op__ rhs.bits_; \ - return out; \ - } - -bfloat_bitop(|, operator|); -bfloat_bitop(&, operator&); -bfloat_bitop(^, operator^); - -#undef bfloat_bitop - -#define bfloat_inplace_bitop(__op__, __operator__) \ - inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return lhs; \ - } \ - inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs; \ - return lhs; \ - } - -bfloat_inplace_bitop(|, operator|=); -bfloat_inplace_bitop(&, operator&=); -bfloat_inplace_bitop(^, operator^=); - -#undef bfloat_inplace_bitop - -typedef struct _MLX_BFloat16 bfloat16_t; - -#endif // __ARM_FEATURE_BF16 - -#ifdef ADD_HALF_BINOPS +#pragma once // clang-format off -#define fp16_bf16_binop_helper(__op__, __operator__) \ - inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -fp16_bf16_binop_helper(+, operator+) -fp16_bf16_binop_helper(-, operator-) -fp16_bf16_binop_helper(*, operator*) -fp16_bf16_binop_helper(/, operator/) +#include "mlx/types/half_types.h" +#include "mlx/types/complex.h" +#include "mlx/backend/common/ops.h" // clang-format on -#endif - - -struct complex64_t; - -template -static constexpr bool can_convert_to_complex64 = - !std::is_same_v && std::is_convertible_v; - -struct complex64_t : public std::complex { - complex64_t(float v, float u) : std::complex(v, u){}; - complex64_t(std::complex v) : std::complex(v){}; - - template < - typename T, - typename = typename std::enable_if>::type> - complex64_t(T x) : std::complex(x){}; - - operator float() const { - return real(); - }; -}; - -inline bool operator>=(const complex64_t& a, const complex64_t& b) { - return (a.real() > b.real()) || - (a.real() == b.real() && a.imag() >= b.imag()); -} - -inline bool operator>(const complex64_t& a, const complex64_t& b) { - return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); -} - -inline complex64_t operator%(complex64_t a, complex64_t b) { - auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); - auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); - if (real != 0 && (real < 0 != b.real() < 0)) - real += b.real(); - if (imag != 0 && (imag < 0 != b.imag() < 0)) - imag += b.imag(); - return {real, imag}; -} - -inline bool operator<=(const complex64_t& a, const complex64_t& b) { - return operator>=(b, a); -} - -inline bool operator<(const complex64_t& a, const complex64_t& b) { - return operator>(b, a); -} - -inline complex64_t operator-(const complex64_t& v) { - return -static_cast>(v); -} - -// clang-format off -#define complex_binop_helper(_op_, _operator_, itype) \ - inline complex64_t _operator_(itype x, const complex64_t& y) { \ - return static_cast(x) _op_ y; \ - } \ - inline complex64_t _operator_(const complex64_t& x, itype y) { \ - return x _op_ static_cast(y); \ - } - -#define complex_binop(_op_, _operator_) \ - inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ - return x _op_ static_cast>(y); \ - } \ - inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ - return static_cast>(x) _op_ y; \ - } \ - inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ - return static_cast>(x) \ - _op_ static_cast>(y); \ - } \ - complex_binop_helper(_op_, _operator_, bool) \ - complex_binop_helper(_op_, _operator_, uint32_t) \ - complex_binop_helper(_op_, _operator_, uint64_t) \ - complex_binop_helper(_op_, _operator_, int32_t) \ - complex_binop_helper(_op_, _operator_, int64_t) \ - complex_binop_helper(_op_, _operator_, float16_t) \ - complex_binop_helper(_op_, _operator_, bfloat16_t) \ - complex_binop_helper(_op_, _operator_, float) -// clang-format on - -complex_binop(+, operator+) - -typedef union { - int i; - float f; -} IntOrFloat; - -inline float fast_exp(float x) { - x *= 1.442695; // multiply with log_2(e) - float ipart, fpart; - IntOrFloat epart; - x = std::max(-80.f, std::min(x, 80.f)); - ipart = std::floor(x + 0.5); - fpart = x - ipart; - - x = 1.535336188319500e-4f; - x = x * fpart + 1.339887440266574e-3f; - x = x * fpart + 9.618437357674640e-3f; - x = x * fpart + 5.550332471162809e-2f; - x = x * fpart + 2.402264791363012e-1f; - x = x * fpart + 6.931472028550421e-1f; - x = x * fpart + 1.000000000000000f; - - // generate 2**ipart in the floating point representation using integer - // bitshifting - epart.i = (int(ipart) + 127) << 23; - - return epart.f * x; -} - -float fast_erf(float a) { - float r, s, t, u; - t = std::abs(a); - s = a * a; - if (t > 0.927734375f) { - // maximum error 0.99527 ulp - r = std::fma( - -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 - u = std::fma( - -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 - r = std::fma(r, s, u); - r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 - r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 - r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 - r = std::fma(r, t, -t); - // TODO, replace with expm1 when implemented - r = 1.0f - std::exp(r); - r = std::copysign(r, a); - } else { - // maximum error 0.98929 ulp - r = -5.96761703e-4f; // -0x1.38e000p-11 - r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 - r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 - r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 - r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 - r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 - r = std::fma(r, a, a); - } - return r; -} - -float fast_erfinv(float a) { - auto t = std::fma(a, 0.0f - a, 1.0f); - t = std::log(t); - float p; - if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 - p = 3.03697567e-10f; // 0x1.4deb44p-32 - p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 - p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 - p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 - p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 - p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 - p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 - p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 - p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 - } else { // maximum ulp error = 2.35002 - p = 5.43877832e-9f; // 0x1.75c000p-28 - p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 - p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 - p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 - p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 - p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 - p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 - p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 - p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 - p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 - } - return a * p; -} - -struct Abs { - template - T operator()(T x) { - return std::abs(x); - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct ArcCos { - template - T operator()(T x) { - return std::acos(x); - }; -}; - -struct ArcCosh { - template - T operator()(T x) { - return std::acosh(x); - }; -}; - -struct ArcSin { - template - T operator()(T x) { - return std::asin(x); - }; -}; - -struct ArcSinh { - template - T operator()(T x) { - return std::asinh(x); - }; -}; - -struct ArcTan { - template - T operator()(T x) { - return std::atan(x); - }; -}; - -struct ArcTanh { - template - T operator()(T x) { - return std::atanh(x); - }; -}; - -struct Ceil { - template - T operator()(T x) { - return std::ceil(x); - }; - int8_t operator()(int8_t x) { - return x; - }; - int16_t operator()(int16_t x) { - return x; - }; - int32_t operator()(int32_t x) { - return x; - }; - int64_t operator()(int64_t x) { - return x; - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct Cos { - template - T operator()(T x) { - return std::cos(x); - }; -}; - -struct Cosh { - template - T operator()(T x) { - return std::cosh(x); - }; -}; - -struct Erf { - template - T operator()(T x) { - return static_cast(fast_erf(static_cast(x))); - }; -}; - -struct ErfInv { - template - T operator()(T x) { - return static_cast(fast_erfinv(static_cast(x))); - }; -}; - -struct Exp { - template - T operator()(T x) { - return fast_exp(x); - }; -}; - -struct Floor { - template - T operator()(T x) { - return std::floor(x); - }; - int8_t operator()(int8_t x) { - return x; - }; - int16_t operator()(int16_t x) { - return x; - }; - int32_t operator()(int32_t x) { - return x; - }; - int64_t operator()(int64_t x) { - return x; - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct Log { - template - T operator()(T x) { - return std::log(x); - }; -}; - -struct Log2 { - template - T operator()(T x) { - return std::log2(x); - }; -}; - -struct Log10 { - template - T operator()(T x) { - return std::log10(x); - }; -}; - -struct Log1p { - template - T operator()(T x) { - return log1p(x); - }; -}; - -struct LogicalNot { - template - T operator()(T x) { - return !x; - }; -}; - -struct Negative { - template - T operator()(T x) { - return -x; - }; -}; - -struct Round { - template - T operator()(T x) { - return std::rint(x); - } - - std::complex operator()(std::complex x) { - return {std::rint(x.real()), std::rint(x.imag())}; - } -}; - -struct Sigmoid { - template - T operator()(T x) { - auto one = static_cast(1.0); - return one / (one + fast_exp(-x)); - } -}; - -struct Sign { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - } - uint8_t operator()(uint8_t x) { - return x != 0; - } - uint16_t operator()(uint16_t x) { - return x != 0; - } - uint32_t operator()(uint32_t x) { - return x != 0; - } - uint64_t operator()(uint64_t x) { - return x != 0; - } -}; - -struct Sin { - template - T operator()(T x) { - return std::sin(x); - }; -}; - -struct Sinh { - template - T operator()(T x) { - return std::sinh(x); - }; -}; - -struct Square { - template - T operator()(T x) { - return x * x; - }; -}; - -struct Sqrt { - template - T operator()(T x) { - return std::sqrt(x); - }; -}; - -struct Rsqrt { - template - T operator()(T x) { - return static_cast(1.0) / std::sqrt(x); - }; -}; - -struct Tan { - template - T operator()(T x) { - return std::tan(x); - }; -}; - -struct Tanh { - template - T operator()(T x) { - return std::tanh(x); - }; -}; - -struct Add { - template - T operator()(T x, T y) { - return x + y; - } -}; - -struct Divide { - template - T operator()(T x, T y) { - return x / y; - } -}; - -struct Remainder { - template - std::enable_if_t & !std::is_signed_v, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } - - template - std::enable_if_t & std::is_signed_v, T> operator()( - T numerator, - T denominator) { - auto r = numerator % denominator; - if (r != 0 && (r < 0 != denominator < 0)) - r += denominator; - return r; - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - auto r = std::fmod(numerator, denominator); - if (r != 0 && (r < 0 != denominator < 0)) { - r += denominator; - } - return r; - } - - std::complex operator()( - std::complex a, std::complex b) { - auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); - auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); - if (real != 0 && ((real < 0) != (b.real() < 0))) - real += b.real(); - if (imag != 0 && ((imag < 0) != (b.imag() < 0))) - imag += b.imag(); - return {real, imag}; - } -}; - -struct Equal { - template - bool operator()(T x, T y) { - return x == y; - } -}; - -struct NaNEqual { - template - bool operator()(T x, T y) { - return x == y || (std::isnan(x) && std::isnan(y)); - } -}; - -struct Greater { - template - bool operator()(T x, T y) { - return x > y; - } -}; - -struct GreaterEqual { - template - bool operator()(T x, T y) { - return x >= y; - } -}; - -struct Less { - template - bool operator()(T x, T y) { - return x < y; - } -}; - -struct LessEqual { - template - bool operator()(T x, T y) { - return x <= y; - } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - constexpr float inf = std::numeric_limits::infinity(); - auto maxval = (x > y) ? x : y; - auto minval = (x > y) ? y : x; - return (minval == -inf || maxval == inf) - ? maxval - : static_cast( - maxval + std::log1p(fast_exp(minval - maxval))); - }; -}; - -struct Maximum { - template - std::enable_if_t, T> operator()(T x, T y) { - return (x > y) ? x : y; - } - - template - std::enable_if_t, T> operator()(T x, T y) { - if (std::isnan(x)) { - return x; - } - return (x > y) ? x : y; - } -}; - -struct Minimum { - template - std::enable_if_t, T> operator()(T x, T y) { - return x < y ? x : y; - } - - template - std::enable_if_t, T> operator()(T x, T y) { - if (std::isnan(x)) { - return x; - } - return x < y ? x : y; - } -}; - -struct Multiply { - template - T operator()(T x, T y) { - return x * y; - } -}; - -struct NotEqual { - template - bool operator()(T x, T y) { - return x != y; - } -}; - -struct Power { - template - std::enable_if_t, T> operator()(T base, T exp) { - return std::pow(base, exp); - } - - template - std::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } -}; - -struct Subtract { - template - T operator()(T x, T y) { - return x - y; - } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { - return x && y; - }; -}; - -struct LogicalOr { - template - T operator()(T x, T y) { - return x || y; - }; -}; -)"; +const char* get_kernel_preamble(); diff --git a/mlx/backend/common/erf.h b/mlx/backend/common/erf.h deleted file mode 100644 index a175a0c43..000000000 --- a/mlx/backend/common/erf.h +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright © 2023 Apple Inc. - -namespace mlx::core { - -/* Approximation to the inverse error function. - * Based on code from: - * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348 - */ -float erfinv(float a); - -} // namespace mlx::core diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh new file mode 100644 index 000000000..687f4cfc7 --- /dev/null +++ b/mlx/backend/common/make_compiled_preamble.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# +# This script generates a C++ function that provides the CPU +# code for use with kernel generation. +# +# Copyright © 2023-24 Apple Inc. + + +OUTPUT_FILE=$1 +GCC=$2 +SRCDIR=$3 +CLANG=$4 + +if [ $CLANG = "TRUE" ]; then + read -r -d '' INCLUDES <<- EOM + #include + #include + #include + #include +EOM + +fi + +CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null) + +cat << EOF > "$OUTPUT_FILE" +const char* get_kernel_preamble() { +return R"preamble( +$INCLUDES +$CONTENT +using namespace mlx::core::detail; +)preamble"; +} +EOF diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h new file mode 100644 index 000000000..8b2d7ab58 --- /dev/null +++ b/mlx/backend/common/ops.h @@ -0,0 +1,591 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once +#include +#include +#include + +namespace mlx::core::detail { + +typedef union { + int i; + float f; +} IntOrFloat; + +inline float fast_exp(float x) { + if (x == -std::numeric_limits::infinity()) { + return 0.0f; + } else if (x == std::numeric_limits::infinity() || std::isnan(x)) { + return x; + } + x *= 1.442695; // multiply with log_2(e) + float ipart, fpart; + IntOrFloat epart; + x = std::max(-80.f, std::min(x, 80.f)); + ipart = std::floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = x * fpart + 1.339887440266574e-3f; + x = x * fpart + 9.618437357674640e-3f; + x = x * fpart + 5.550332471162809e-2f; + x = x * fpart + 2.402264791363012e-1f; + x = x * fpart + 6.931472028550421e-1f; + x = x * fpart + 1.000000000000000f; + + // generate 2**ipart in the floating point representation using integer + // bitshifting + epart.i = (int(ipart) + 127) << 23; + + return epart.f * x; +} + +inline float fast_erf(float a) { + float r, s, t, u; + t = std::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = std::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = std::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = std::fma(r, s, u); + r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = std::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - std::exp(r); + r = std::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = std::fma(r, a, a); + } + return r; +} + +inline float fast_erfinv(float a) { + auto t = std::fma(a, 0.0f - a, 1.0f); + t = std::log(t); + float p; + if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} + +struct Abs { + template + T operator()(T x) { + return std::abs(x); + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return std::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return std::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return std::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return std::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return std::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return std::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return std::ceil(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return std::cos(x); + }; +}; + +struct Cosh { + template + T operator()(T x) { + return std::cosh(x); + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(fast_erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(fast_erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return fast_exp(x); + }; + + complex64_t operator()(complex64_t x) { + return std::exp(x); + } +}; + +struct Floor { + template + T operator()(T x) { + return std::floor(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return std::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return std::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return std::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return std::rint(x); + } + + complex64_t operator()(complex64_t x) { + return {std::rint(x.real()), std::rint(x.imag())}; + } +}; + +struct Sigmoid { + template + T operator()(T x) { + auto one = static_cast(1.0); + return one / (one + fast_exp(-x)); + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + } + uint8_t operator()(uint8_t x) { + return x != 0; + } + uint16_t operator()(uint16_t x) { + return x != 0; + } + uint32_t operator()(uint32_t x) { + return x != 0; + } + uint64_t operator()(uint64_t x) { + return x != 0; + } +}; + +struct Sin { + template + T operator()(T x) { + return std::sin(x); + }; +}; + +struct Sinh { + template + T operator()(T x) { + return std::sinh(x); + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return std::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return static_cast(1.0) / std::sqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return std::tan(x); + }; +}; + +struct Tanh { + template + T operator()(T x) { + return std::tanh(x); + }; +}; + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + std::enable_if_t & !std::is_signed_v, T> operator()( + T numerator, + T denominator) { + return numerator % denominator; + } + + template + std::enable_if_t & std::is_signed_v, T> operator()( + T numerator, + T denominator) { + auto r = numerator % denominator; + if (r != 0 && (r < 0 != denominator < 0)) + r += denominator; + return r; + } + + template + std::enable_if_t, T> operator()( + T numerator, + T denominator) { + auto r = std::fmod(numerator, denominator); + if (r != 0 && (r < 0 != denominator < 0)) { + r += denominator; + } + return r; + } + + complex64_t operator()(complex64_t numerator, complex64_t denominator) { + return numerator % denominator; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (std::isnan(x) && std::isnan(y)); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct Maximum { + template + std::enable_if_t, T> operator()(T x, T y) { + return (x > y) ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return (x > y) ? x : y; + } +}; + +struct Minimum { + template + std::enable_if_t, T> operator()(T x, T y) { + return x < y ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return x < y ? x : y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + constexpr float inf = std::numeric_limits::infinity(); + auto maxval = Maximum()(x, y); + auto minval = Minimum()(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : static_cast( + maxval + std::log1p(fast_exp(minval - maxval))); + }; +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } +}; + +struct Power { + template + std::enable_if_t, T> operator()(T base, T exp) { + return std::pow(base, exp); + } + + template + std::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; + +} // namespace mlx::core::detail diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 37c61761a..a1e99d7c7 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -10,7 +10,7 @@ #include "mlx/backend/common/arange.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/copy.h" -#include "mlx/backend/common/erf.h" +#include "mlx/backend/common/ops.h" #include "mlx/backend/common/threefry.h" #include "mlx/backend/common/unary.h" #include "mlx/backend/common/utils.h" @@ -26,7 +26,7 @@ void Abs::eval(const std::vector& inputs, array& out) { // No-op for unsigned types out.copy_shared_buffer(in); } else { - unary(in, out, AbsOp()); + unary(in, out, detail::Abs()); } } @@ -38,7 +38,7 @@ void ArcCos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::acos(x); }); + unary_fp(in, out, detail::ArcCos()); } else { throw std::invalid_argument( "[arccos] Cannot compute inverse cosine of elements in array" @@ -50,7 +50,7 @@ void ArcCosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::acosh(x); }); + unary_fp(in, out, detail::ArcCosh()); } else { throw std::invalid_argument( "[arccosh] Cannot compute inverse hyperbolic cosine of elements in" @@ -62,7 +62,7 @@ void ArcSin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::asin(x); }); + unary_fp(in, out, detail::ArcSin()); } else { throw std::invalid_argument( "[arcsin] Cannot compute inverse sine of elements in array" @@ -74,7 +74,7 @@ void ArcSinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::asinh(x); }); + unary_fp(in, out, detail::ArcSinh()); } else { throw std::invalid_argument( "[arcsinh] Cannot compute inverse hyperbolic sine of elements in" @@ -86,7 +86,7 @@ void ArcTan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::atan(x); }); + unary_fp(in, out, detail::ArcTan()); } else { throw std::invalid_argument( "[arctan] Cannot compute inverse tangent of elements in array" @@ -98,7 +98,7 @@ void ArcTanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::atanh(x); }); + unary_fp(in, out, detail::ArcTanh()); } else { throw std::invalid_argument( "[arctanh] Cannot compute inverse hyperbolic tangent of elements in" @@ -172,7 +172,7 @@ void Ceil::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, [](auto x) { return std::ceil(x); }); + unary_fp(in, out, detail::Ceil()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -212,7 +212,7 @@ void Cos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::cos(x); }); + unary_fp(in, out, detail::Cos()); } else { throw std::invalid_argument( "[cos] Cannot compute cosine of elements in array" @@ -224,7 +224,7 @@ void Cosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::cosh(x); }); + unary_fp(in, out, detail::Cosh()); } else { throw std::invalid_argument( "[cosh] Cannot compute hyperbolic cosine of elements in array" @@ -256,17 +256,13 @@ void Erf::eval(const std::vector& inputs, array& out) { const auto& in = inputs[0]; switch (out.dtype()) { case float32: - unary_op(in, out, [](auto x) { return std::erf(x); }); + unary_op(in, out, detail::Erf()); break; case float16: - unary_op(in, out, [](auto x) { - return static_cast(std::erf(static_cast(x))); - }); + unary_op(in, out, detail::Erf()); break; case bfloat16: - unary_op(in, out, [](auto x) { - return static_cast(std::erf(static_cast(x))); - }); + unary_op(in, out, detail::Erf()); break; default: throw std::invalid_argument( @@ -280,17 +276,13 @@ void ErfInv::eval(const std::vector& inputs, array& out) { const auto& in = inputs[0]; switch (out.dtype()) { case float32: - unary_op(in, out, [](auto x) { return erfinv(x); }); + unary_op(in, out, detail::ErfInv()); break; case float16: - unary_op(in, out, [](auto x) { - return static_cast(erfinv(static_cast(x))); - }); + unary_op(in, out, detail::ErfInv()); break; case bfloat16: - unary_op(in, out, [](auto x) { - return static_cast(erfinv(static_cast(x))); - }); + unary_op(in, out, detail::ErfInv()); break; default: throw std::invalid_argument( @@ -302,9 +294,8 @@ void ErfInv::eval(const std::vector& inputs, array& out) { void Exp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::exp(x); }); + unary_fp(in, out, detail::Exp()); } else { throw std::invalid_argument( "[exp] Cannot exponentiate elements in array" @@ -316,7 +307,7 @@ void Floor::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, [](auto x) { return std::floor(x); }); + unary_fp(in, out, detail::Floor()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -344,13 +335,13 @@ void Log::eval(const std::vector& inputs, array& out) { if (is_floating_point(out.dtype())) { switch (base_) { case Base::e: - unary_fp(in, out, [](auto x) { return std::log(x); }); + unary_fp(in, out, detail::Log()); break; case Base::two: - unary_fp(in, out, [](auto x) { return std::log2(x); }); + unary_fp(in, out, detail::Log2()); break; case Base::ten: - unary_fp(in, out, [](auto x) { return std::log10(x); }); + unary_fp(in, out, detail::Log10()); break; } } else { @@ -364,7 +355,7 @@ void Log1p::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::log1p(x); }); + unary_fp(in, out, detail::Log1p()); } else { throw std::invalid_argument( "[log1p] Cannot compute log of elements in array with" @@ -375,27 +366,27 @@ void Log1p::eval(const std::vector& inputs, array& out) { void LogicalNot::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return !x; }); + unary(in, out, detail::LogicalNot()); } void LogicalAnd::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalAnd requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, [](auto x, auto y) { return x && y; }); + binary(in1, in2, out, detail::LogicalAnd()); } void LogicalOr::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalOr requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, [](auto x, auto y) { return x || y; }); + binary(in1, in2, out, detail::LogicalOr()); } void Negative::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return -x; }); + unary(in, out, detail::Negative()); } void Pad::eval(const std::vector& inputs, array& out) { @@ -498,7 +489,7 @@ void Round::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, RoundOp()); + unary_fp(in, out, detail::Round()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -509,11 +500,7 @@ void Sigmoid::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - auto sigmoid_op = [](auto x) { - auto one = static_cast(1.0); - return one / (one + std::exp(-x)); - }; - unary_fp(in, out, sigmoid_op); + unary_fp(in, out, detail::Sigmoid()); } else { throw std::invalid_argument( "[sigmoid] Cannot sigmoid of elements in array with" @@ -527,7 +514,7 @@ void Sign::eval(const std::vector& inputs, array& out) { if (in.dtype() == bool_) { out.copy_shared_buffer(in); } else { - unary(in, out, SignOp()); + unary(in, out, detail::Sign()); } } @@ -535,7 +522,7 @@ void Sin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::sin(x); }); + unary_fp(in, out, detail::Sin()); } else { throw std::invalid_argument( "[sin] Cannot compute sine of elements in array" @@ -547,7 +534,7 @@ void Sinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::sinh(x); }); + unary_fp(in, out, detail::Sinh()); } else { throw std::invalid_argument( "[sinh] Cannot compute hyperbolic sine of elements in array" @@ -656,18 +643,16 @@ void Split::eval( void Square::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return x * x; }); + unary(in, out, detail::Square()); } void Sqrt::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (recip_) { - unary_fp(in, out, [](auto x) { - return static_cast(1.0) / sqrt(x); - }); + unary_fp(in, out, detail::Rsqrt()); } else { - unary_fp(in, out, [](auto x) { return sqrt(x); }); + unary_fp(in, out, detail::Sqrt()); } } @@ -680,7 +665,7 @@ void Tan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::tan(x); }); + unary_fp(in, out, detail::Tan()); } else { throw std::invalid_argument( "[tan] Cannot compute tangent of elements in array" @@ -692,7 +677,7 @@ void Tanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::tanh(x); }); + unary_fp(in, out, detail::Tanh()); } else { throw std::invalid_argument( "[tanh] Cannot compute hyperbolic tangent of elements in array" diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index 7fdcbeb77..28c4f0f4a 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -11,59 +11,6 @@ namespace mlx::core { namespace { -struct AbsOp { - template - T operator()(T x) { - return std::abs(x); - } - uint8_t operator()(uint8_t x) { - return x; - } - uint16_t operator()(uint16_t x) { - return x; - } - uint32_t operator()(uint32_t x) { - return x; - } - uint64_t operator()(uint64_t x) { - return x; - } - bool operator()(bool x) { - return x; - } -}; - -struct SignOp { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - } - - uint8_t operator()(uint8_t x) { - return x != 0; - } - uint16_t operator()(uint16_t x) { - return x != 0; - } - uint32_t operator()(uint32_t x) { - return x != 0; - } - uint64_t operator()(uint64_t x) { - return x != 0; - } -}; - -struct RoundOp { - template - T operator()(T x) { - return std::rint(x); - } - - complex64_t operator()(complex64_t x) { - return {std::rint(x.real()), std::rint(x.imag())}; - } -}; - void set_unary_output_data(const array& in, array& out) { if (in.is_donatable() && in.itemsize() == out.itemsize()) { out.copy_shared_buffer(in); diff --git a/mlx/types/complex.h b/mlx/types/complex.h index 46f4310f9..f8a607766 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -38,9 +38,9 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) { inline complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); - if (real != 0 && (real < 0 != b.real() < 0)) + if (real != 0 && ((real < 0) != (b.real() < 0))) real += b.real(); - if (imag != 0 && (imag < 0 != b.imag() < 0)) + if (imag != 0 && ((imag < 0) != (b.imag() < 0))) imag += b.imag(); return {real, imag}; } diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index e52c1294f..41db064be 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1002,7 +1002,7 @@ TEST_CASE("test arithmetic unary ops") { CHECK_EQ(exp(x).item(), 1.0); x = array(2.0); - CHECK_EQ(exp(x).item(), std::exp(2.0f)); + CHECK_EQ(exp(x).item(), doctest::Approx(std::exp(2.0f))); CHECK(array_equal(exp(array({})), array({})).item()); @@ -1012,7 +1012,7 @@ TEST_CASE("test arithmetic unary ops") { // Integer input type x = array(2); CHECK_EQ(x.dtype(), int32); - CHECK_EQ(exp(x).item(), std::exp(2.0f)); + CHECK_EQ(exp(x).item(), doctest::Approx(std::exp(2.0f))); // Input is irregularly strided x = broadcast_to(array(1.0f), {2, 2, 2}); @@ -1020,7 +1020,7 @@ TEST_CASE("test arithmetic unary ops") { x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); - CHECK(array_equal(exp(x), expected).item()); + CHECK(allclose(exp(x), expected).item()); } // Test sine From e1bdf6a8d9ba140352f04dd5f714e419ae38c18f Mon Sep 17 00:00:00 2001 From: Diogo Date: Mon, 19 Feb 2024 10:03:56 -0500 Subject: [PATCH 39/42] discover doctests in cmake (#703) --- tests/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index be9570f7e..34b69f233 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,6 +14,8 @@ if (MLX_BUILD_METAL) ) endif() +include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) + target_sources(tests PRIVATE allocator_tests.cpp array_tests.cpp @@ -37,4 +39,5 @@ target_sources(tests PRIVATE ) target_link_libraries(tests PRIVATE mlx doctest) +doctest_discover_tests(tests) add_test(NAME tests COMMAND tests) From f883fcede082c7198f32b4793f60d26d3195c2fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Mon, 19 Feb 2024 12:40:52 -0500 Subject: [PATCH 40/42] Added support for atleast_1d, atleast_2d, atleast_3d (#694) --- ACKNOWLEDGMENTS.md | 3 +- docs/src/python/ops.rst | 3 ++ mlx/ops.cpp | 30 ++++++++++++++ mlx/ops.h | 5 +++ python/src/ops.cpp | 60 +++++++++++++++++++++++++++ python/tests/test_ops.py | 90 ++++++++++++++++++++++++++++++++++++++++ tests/ops_tests.cpp | 51 +++++++++++++++++++++++ 7 files changed, 241 insertions(+), 1 deletion(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 36aedc77a..c2cad615e 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -10,8 +10,9 @@ MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. -- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support +- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. +- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 09e2d5f71..7ec7defc9 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -25,6 +25,9 @@ Operations argpartition argsort array_equal + atleast_1d + atleast_2d + atleast_3d broadcast_to ceil clip diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 32af8a078..97d4a3a2d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3381,4 +3381,34 @@ std::vector depends( shapes, dtypes, std::make_shared(to_stream(s)), all_inputs); } +array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() == 0) { + return reshape(a, {1}, s); + } + return a; +} + +array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1}, s); + case 1: + return reshape(a, {1, static_cast(a.size())}, s); + default: + return a; + } +} + +array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1, 1}, s); + case 1: + return reshape(a, {1, static_cast(a.size()), 1}, s); + case 2: + return reshape(a, {a.shape(0), a.shape(1), 1}, s); + default: + return a; + } +} } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index f7036b8c6..b61224d65 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1121,4 +1121,9 @@ std::vector depends( const std::vector& inputs, const std::vector& dependencies); +/** convert an array to an atleast ndim array */ +array atleast_1d(const array& a, StreamOrDevice s = {}); +array atleast_2d(const array& a, StreamOrDevice s = {}); +array atleast_3d(const array& a, StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8e08e6ca9..2c2dcecfd 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3636,4 +3636,64 @@ void init_ops(py::module_& m) { Returns: array: The extracted diagonal or the constructed diagonal matrix. )pbdoc"); + m.def( + "atleast_1d", + &atleast_1d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least one dimension. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least one dimension. + + )pbdoc"); + m.def( + "atleast_2d", + &atleast_2d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least two dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least two dimensions. + + )pbdoc"); + m.def( + "atleast_3d", + &atleast_3d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least three dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least three dimensions. + + )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 66e683303..3401338f8 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1883,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.array(np.diag(x, k=-1)) self.assertTrue(mx.array_equal(result, expected)) + def test_atleast_1d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_1d(mx.array(array)) + np_res = np.atleast_1d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_2d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_2d(mx.array(array)) + np_res = np.atleast_2d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_3d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_3d(mx.array(array)) + np_res = np.atleast_3d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 41db064be..ba4ab552f 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2716,3 +2716,54 @@ TEST_CASE("test diag") { out = diag(x, -1); CHECK(array_equal(out, array({3, 7}, {2})).item()); } + +TEST_CASE("test atleast_1d") { + auto x = array(1); + auto out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector{1}); + + x = array({1, 2, 3}, {3}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector{3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{3, 1}); +} + +TEST_CASE("test atleast_2d") { + auto x = array(1); + auto out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{1, 3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{3, 1}); +} + +TEST_CASE("test atleast_3d") { + auto x = array(1); + auto out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{1, 1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{1, 3, 1}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{3, 1, 1}); +} \ No newline at end of file From d0fda82595fc238c1ed62b2505cd1bbc551689c0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 Feb 2024 09:44:27 -0800 Subject: [PATCH 41/42] fix tolist for half types (#702) --- python/src/array.cpp | 10 +++++----- python/tests/test_array.py | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 57b867dbc..6dd2f290b 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -23,15 +23,15 @@ enum PyScalarT { pycomplex = 3, }; -template +template py::list to_list(array& a, size_t index, int dim) { py::list pl; auto stride = a.strides()[dim]; for (int i = 0; i < a.shape(dim); ++i) { if (dim == a.ndim() - 1) { - pl.append((a.data()[index])); + pl.append(static_cast(a.data()[index])); } else { - pl.append(to_list(a, index, dim + 1)); + pl.append(to_list(a, index, dim + 1)); } index += stride; } @@ -102,11 +102,11 @@ py::object tolist(array& a) { case int64: return to_list(a, 0, 0); case float16: - return to_list(a, 0, 0); + return to_list(a, 0, 0); case float32: return to_list(a, 0, 0); case bfloat16: - return to_list(a, 0, 0); + return to_list(a, 0, 0); case complex64: return to_list>(a, 0, 0); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 507675d6e..7812642d3 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase): x = mx.array(vals) self.assertEqual(x.tolist(), vals) + # Half types + vals = [1.0, 2.0, 3.0, 4.0, 5.0] + x = mx.array(vals, dtype=mx.float16) + self.assertEqual(x.tolist(), vals) + + x = mx.array(vals, dtype=mx.bfloat16) + self.assertEqual(x.tolist(), vals) + def test_array_np_conversion(self): # Shape test a = np.array([]) From 5798256fcf6487c80653524f4bd9abdc49ef5eb3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 Feb 2024 21:43:54 -0800 Subject: [PATCH 42/42] Shapeless compilation for some graphs (#687) * shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos --- benchmarks/python/compile_bench.py | 109 +++++++++++++++++++ benchmarks/python/time_utils.py | 6 +- mlx/backend/common/compiled.cpp | 12 +-- mlx/backend/common/compiled.h | 4 + mlx/backend/metal/compiled.cpp | 12 +-- mlx/compile.cpp | 118 ++++++++++++++++----- mlx/compile.h | 5 +- mlx/primitives.cpp | 25 +++++ mlx/primitives.h | 80 +++++++++++++- mlx/transforms_impl.h | 4 +- python/mlx/nn/layers/activations.py | 83 +++++++++------ python/src/transforms.cpp | 78 ++++++++++++-- python/tests/test_compile.py | 158 ++++++++++++++++++++++++++++ tests/compile_tests.cpp | 64 +++++++---- 14 files changed, 645 insertions(+), 113 deletions(-) create mode 100644 benchmarks/python/compile_bench.py diff --git a/benchmarks/python/compile_bench.py b/benchmarks/python/compile_bench.py new file mode 100644 index 000000000..0d5d9f61d --- /dev/null +++ b/benchmarks/python/compile_bench.py @@ -0,0 +1,109 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse +import math +import random + +import mlx.core as mx +from time_utils import time_fn + + +def bench_gelu(): + + def gelu(x): + return x * (1 + mx.erf(x / math.sqrt(2))) / 2 + + x = mx.random.uniform(shape=(1000, 1024)) + + def gen_fun(fun): + def bench_fun(x): + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + time_fn(gen_fun(gelu), x, msg="fixed gelu") + time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu") + + def randint(): + return random.randint(1, x.shape[0]) + + def gen_fun(fun): + def bench_fun(x, y): + x = x[: randint()] + for _ in range(10): + x = fun(x) + y = fun(y) + return x, y + + return bench_fun + + y = mx.random.uniform(shape=(1000, 1024)) + time_fn(gen_fun(gelu), x, y, msg="variable gelu") + time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu") + time_fn( + gen_fun(mx.compile(gelu, shapeless=True)), + x, + y, + msg="shapeless variable gelu", + ) + + +def bench_layernorm(): + + weight = mx.random.uniform(shape=(4096,)).astype(mx.float16) + bias = mx.random.uniform(shape=(4096,)).astype(mx.float16) + mx.eval(weight, bias) + + def layernorm(x): + x = x.astype(mx.float32) + means = mx.mean(x, axis=-1, keepdims=True) + var = mx.var(x, axis=-1, keepdims=True) + x = (x - means) * mx.rsqrt(var + 1e-4) + x = x.astype(mx.float16) + return weight * x + bias + + x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16) + + def gen_fun(fun): + def bench_fun(x): + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + time_fn(gen_fun(layernorm), x, msg="fixed layernorm") + time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm") + + def randint(): + return random.randint(1, x.shape[0]) + + def gen_fun(fun): + def bench_fun(x): + x = x[: randint()] + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + random.seed(0) + time_fn(gen_fun(layernorm), x, msg="variable layernorm") + random.seed(0) + time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm") + random.seed(0) + time_fn( + gen_fun(mx.compile(layernorm, shapeless=True)), + x, + msg="shapeless variable layernorm", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Compile benchmarks.") + args = parser.parse_args() + + bench_gelu() + bench_layernorm() diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py index f10635ec9..2903c3293 100644 --- a/benchmarks/python/time_utils.py +++ b/benchmarks/python/time_utils.py @@ -6,7 +6,11 @@ import mlx.core as mx def time_fn(fn, *args, **kwargs): - print(f"Timing {fn.__name__} ...", end=" ") + msg = kwargs.pop("msg", None) + if msg: + print(f"Timing {msg} ...", end=" ") + else: + print(f"Timing {fn.__name__} ...", end=" ") # warmup for _ in range(5): diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 52bcac4fa..529ad2fa5 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -37,7 +37,7 @@ std::string build_lib_name( os << "C"; print_constant(constant_hasher, x); } else { - os << ((x.size() == 1) ? "S" : "V"); + os << (is_scalar(x) ? "S" : "V"); } } os << "_"; @@ -122,10 +122,6 @@ std::string get_type_string(Dtype d) { } } -inline bool is_scalar(const array& x) { - return x.size() == 1; -}; - // Return a pointer to a compiled function void* compile( const std::string& kernel_name, @@ -358,7 +354,7 @@ void Compiled::eval_cpu( bool all_col_contig = true; int non_scalar_inputs = 0; for (auto& x : inputs) { - if (x.size() == 1) { + if (is_scalar(x)) { continue; } non_scalar_inputs++; @@ -385,7 +381,7 @@ void Compiled::eval_cpu( auto& x = inputs[i]; args.push_back((void*)x.data()); - if (contiguous || x.size() <= 1) { + if (contiguous || is_scalar(x)) { continue; } @@ -458,7 +454,7 @@ void Compiled::eval_cpu( // - Donatable // - Correct size // - Not a constant - if (in.flags().contiguous && in.size() > 1 && in.is_donatable() && + if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() && constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { outputs[o++].copy_shared_buffer(in); } diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index adbd5399c..d01fe4fdc 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -49,4 +49,8 @@ void print_complex_constant(std::ostream& os, const array& x) { void print_constant(std::ostream& os, const array& x); +inline bool is_scalar(const array& x) { + return x.ndim() == 0; +} + } // namespace mlx::core diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 681d635ba..3b1ee116a 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -31,9 +31,6 @@ inline void build_kernel( return constant_ids.find(x.id()) != constant_ids.end(); }; - // For scalar we shouldn't do the indexing things, just read at 0 - auto is_scalar = [](const array& x) { return x.size() == 1; }; - NodeNamer namer; bool add_indices = false; int cnt = 0; @@ -226,8 +223,7 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ true); - kernel_source_ = kernel.str(); - lib = d.get_library(kernel_lib_, kernel_source_); + lib = d.get_library(kernel_lib_, kernel.str()); } // Figure out which kernel we are using @@ -235,7 +231,7 @@ void Compiled::eval_gpu( bool contiguous = true; for (auto& x : inputs) { if ((!x.flags().row_contiguous || x.shape() != output_shape) && - x.size() > 1) { + !is_scalar(x)) { contiguous = false; break; } @@ -256,7 +252,7 @@ void Compiled::eval_gpu( auto& x = inputs[i]; // Skip scalar inputs. - if (x.size() <= 1) { + if (is_scalar(x)) { continue; } @@ -311,7 +307,7 @@ void Compiled::eval_gpu( } auto& x = inputs[i]; set_array_buffer(compute_encoder, x, cnt++); - if (!contiguous && x.size() > 1) { + if (!contiguous && !is_scalar(x)) { compute_encoder->setBytes( strides[stride_idx].data(), strides[stride_idx].size() * sizeof(size_t), diff --git a/mlx/compile.cpp b/mlx/compile.cpp index a648d191f..700c07ced 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -13,7 +13,7 @@ namespace mlx::core { -constexpr int max_compile_depth = 10; +constexpr int max_compile_depth = 11; bool is_unary(const Primitive& p) { return ( @@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) { return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient); } +bool is_reduction(const Primitive& p) { + return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce); +} + bool is_fusable(const Primitive& p) { return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p); } -namespace detail { - -std::vector compile_replace( - const std::vector& tape, - const std::vector& trace_inputs, - const std::vector& trace_outputs, - const std::vector& inputs); - -} // namespace detail +bool allows_shapeless(const Primitive& p) { + return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) || + is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) || + typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) || + typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition); +} Compiled::Compiled( Stream stream, @@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) { } } +std::vector> Compiled::output_shapes( + const std::vector& inputs) { + size_t nd = 0; + for (auto& in : inputs) { + nd = std::max(nd, in.ndim()); + } + std::vector out_shape(nd, 0); + for (auto& in : inputs) { + auto dd = nd - in.ndim(); + for (auto i = dd; i < nd; ++i) { + out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]); + } + } + // All outputs have the same shape + return std::vector>(outputs_.size(), out_shape); +} + namespace detail { CompileMode& compile_mode() { @@ -180,21 +198,30 @@ struct CompilerCache { std::vector outputs; std::vector tape; bool empty{true}; + std::vector constants; }; // Returns a reference to a CacheEntry which can be updated // by the caller to avoid copying large tapes / inputs / outputs - CacheEntry& find(size_t fun_id, const std::vector& inputs) { + CacheEntry& find( + size_t fun_id, + const std::vector& inputs, + bool shapeless, + const std::vector& constants) { // Try to find the entry auto [entry_it, inserted] = cache_.insert({fun_id, {}}); auto& entries = entry_it->second; - auto is_match = [](const std::vector& in1, - const std::vector& in2) { + auto is_match = [shapeless]( + const std::vector& in1, + const std::vector& in2) { if (in1.size() != in2.size()) { return false; } for (int i = 0; i < in1.size(); ++i) { - if (in1[i].shape() != in2[i].shape()) { + if (in1[i].ndim() != in2[i].ndim()) { + return false; + } + if (!shapeless && in1[i].shape() != in2[i].shape()) { return false; } if (in1[i].dtype() != in2[i].dtype()) { @@ -210,7 +237,7 @@ struct CompilerCache { // more easily searchable structure. for (auto& entry : entries) { // Check the inputs match and return if so - if (is_match(inputs, entry.inputs)) { + if (is_match(inputs, entry.inputs) && constants == entry.constants) { return entry; } } @@ -651,7 +678,8 @@ std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, const std::vector& trace_outputs, - const std::vector& inputs) { + const std::vector& inputs, + bool shapeless) { std::unordered_map trace_to_real; for (int i = 0; i < inputs.size(); ++i) { trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); @@ -669,18 +697,29 @@ std::vector compile_replace( real_inputs.push_back(trace_to_real.at(in.id())); } if (a.siblings().empty()) { + auto shape = + shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape(); auto real_a = array( - a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); + std::move(shape), + a.dtype(), + a.primitive_ptr(), + std::move(real_inputs)); trace_to_real.insert({a.id(), std::move(real_a)}); } else { // Ensure the order is correct for multi-output primitives - std::vector> shapes; std::vector types; auto trace_out = a.outputs(); for (auto& o : trace_out) { - shapes.push_back(o.shape()); types.push_back(o.dtype()); } + std::vector> shapes; + if (shapeless) { + shapes = a.primitive().output_shapes(real_inputs); + } else { + for (auto& o : trace_out) { + shapes.push_back(o.shape()); + } + } auto real_out = array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs); for (int i = 0; i < trace_out.size(); ++i) { @@ -697,13 +736,34 @@ std::vector compile_replace( return outputs; } +void compile_validate_shapeless(const std::vector& tape) { + for (auto& t : tape) { + if (!t.has_primitive()) { + continue; + } + auto& p = t.primitive(); + if (allows_shapeless(p)) { + continue; + } + + std::ostringstream msg; + msg << "[compile] Cannot compile primitive "; + p.print(msg); + msg << " with shapeless enabled."; + throw std::invalid_argument(msg.str()); + } +} + std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id) { + size_t fun_id, + bool shapeless /* = false */, + std::vector constants /* = {} */) { if (compile_mode() == CompileMode::disabled) { return fun; } - return [fun, fun_id](const std::vector& inputs) { + return [fun, fun_id, shapeless, constants = std::move(constants)]( + const std::vector& inputs) { // If the inputs are tracers, trace the original graph if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { return in.is_tracer(); @@ -712,12 +772,14 @@ std::function(const std::vector&)> compile( } // Find a cache entry with the correct inputs - auto& entry = compiler_cache().find(fun_id, inputs); + auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants); // No matching cache entry existed, so compile if (entry.empty) { // Mark the entry as not empty since we are about to fill it entry.empty = false; + // Set the constants + entry.constants = std::move(constants); // Trace to build the graph std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); @@ -739,11 +801,16 @@ std::function(const std::vector&)> compile( if (compile_mode() != CompileMode::no_fuse) { compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs); } + + if (shapeless) { + compile_validate_shapeless(entry.tape); + } } // At this point we must have a tape, now replace the placeholders // with real arrays that can be evaluated - return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs); + return compile_replace( + entry.tape, entry.inputs, entry.outputs, inputs, shapeless); }; } @@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) { } // namespace detail std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun) { + const std::function(const std::vector&)>& fun, + bool shapeless /* false */) { if (detail::compile_mode() == CompileMode::disabled) { return fun; } auto fun_id = detail::getAddress(fun); - return detail::compile(fun, fun_id); + return detail::compile(fun, fun_id, shapeless); } void disable_compile() { diff --git a/mlx/compile.h b/mlx/compile.h index fb3115d61..1134c20dc 100644 --- a/mlx/compile.h +++ b/mlx/compile.h @@ -8,9 +8,10 @@ namespace mlx::core { enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; -// Compile takes a function and returns a new function +/** Compile takes a function and returns a compiled function. */ std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun); + const std::function(const std::vector&)>& fun, + bool shapeless = false); /** Globally disable compilation. * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b78f8d405..b2daaa2f8 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -71,6 +71,15 @@ std::pair, std::vector> Primitive::vmap( throw std::invalid_argument("Primitive's vmap not implemented."); }; +std::vector> Primitive::output_shapes( + const std::vector&) { + std::ostringstream msg; + msg << "[Primitive::output_shapes] "; + this->print(msg); + msg << " cannot infer output shapes."; + throw std::invalid_argument(msg.str()); +}; + std::vector Abs::vjp( const std::vector& primals, const std::vector& cotangents, @@ -383,6 +392,13 @@ std::pair, std::vector> ArgSort::vmap( return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; } +std::vector> ArgReduce::output_shapes( + const std::vector& inputs) { + auto out_shape = inputs[0].shape(); + out_shape[axis_] = 1; + return {out_shape}; +} + bool ArgSort::is_equivalent(const Primitive& other) const { const ArgSort& r_other = static_cast(other); return axis_ == r_other.axis_; @@ -2202,6 +2218,15 @@ bool Reduce::is_equivalent(const Primitive& other) const { return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_; } +std::vector> Reduce::output_shapes( + const std::vector& inputs) { + std::vector out_shape = inputs[0].shape(); + for (auto i : axes_) { + out_shape[i] = 1; + } + return {out_shape}; +} + std::vector Round::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 9d0a9181c..73e4394a5 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -36,6 +36,12 @@ return true; \ } +#define DEFINE_INPUT_OUTPUT_SHAPE() \ + std::vector> output_shapes( \ + const std::vector& inputs) override { \ + return {inputs[0].shape()}; \ + }; + namespace mlx::core { // Abstract base class @@ -102,6 +108,11 @@ class Primitive { return false; } + /** Get the output shapes of the primitive. This is not required to be + * implemented by derived classes, in which case it will throw. */ + virtual std::vector> output_shapes( + const std::vector& inputs); + virtual ~Primitive() = default; Primitive(const Primitive& other) = delete; Primitive(Primitive&& other) = delete; @@ -152,6 +163,7 @@ class Abs : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -168,6 +180,7 @@ class Add : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Add) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -226,6 +239,7 @@ class ArcCos : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -242,6 +256,7 @@ class ArcCosh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -258,6 +273,7 @@ class ArcSin : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -274,6 +290,7 @@ class ArcSinh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -290,6 +307,7 @@ class ArcTan : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -306,6 +324,7 @@ class ArcTanh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -321,6 +340,7 @@ class ArgPartition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgPartition) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -346,6 +366,8 @@ class ArgReduce : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgReduce) bool is_equivalent(const Primitive& other) const override; + std::vector> output_shapes( + const std::vector& inputs) override; private: ReduceType reduce_type_; @@ -364,6 +386,7 @@ class ArgSort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgSort) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -383,6 +406,7 @@ class AsType : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(AsType) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -448,6 +472,7 @@ class Ceil : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Ceil) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -478,15 +503,14 @@ class Compiled : public Primitive { DEFINE_VMAP() DEFINE_GRADS() + std::vector> output_shapes( + const std::vector& inputs) override; void print(std::ostream& os) override; bool is_equivalent(const Primitive& other) const override; - std::string metal_lib_name() const { + std::string lib_name() const { return kernel_lib_; } - std::string metal_lib_source() const { - return kernel_source_; - } private: const std::vector inputs_; @@ -495,7 +519,6 @@ class Compiled : public Primitive { const std::unordered_set constant_ids_; std::string kernel_lib_; - std::string kernel_source_; }; class Concatenate : public UnaryPrimitive { @@ -563,6 +586,7 @@ class Copy : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -579,6 +603,7 @@ class Cos : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -595,6 +620,7 @@ class Cosh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -665,6 +691,7 @@ class Divide : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -683,6 +710,10 @@ class DivMod : public Primitive { DEFINE_GRADS() DEFINE_PRINT(DivMod) DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector> output_shapes( + const std::vector& inputs) override { + return std::vector{inputs[0].shape(), inputs[0].shape()}; + }; private: void eval(const std::vector& inputs, std::vector& outputs); @@ -699,6 +730,7 @@ class Remainder : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Remainder) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -715,6 +747,7 @@ class Equal : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() void print(std::ostream& os) override { if (equal_nan_) { @@ -740,6 +773,7 @@ class Erf : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -756,6 +790,7 @@ class ErfInv : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -772,6 +807,7 @@ class Exp : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -814,6 +850,7 @@ class Floor : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Floor) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -868,6 +905,7 @@ class Greater : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -884,6 +922,7 @@ class GreaterEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -900,6 +939,7 @@ class Less : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Less) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -916,6 +956,7 @@ class LessEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -958,6 +999,7 @@ class Log : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() void print(std::ostream& os) override { switch (base_) { @@ -988,6 +1030,7 @@ class Log1p : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Log1p) + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1004,6 +1047,7 @@ class LogicalNot : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1020,6 +1064,7 @@ class LogicalAnd : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalAnd) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1036,6 +1081,7 @@ class LogicalOr : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalOr) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1052,6 +1098,7 @@ class LogAddExp : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1085,6 +1132,7 @@ class Maximum : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1101,6 +1149,7 @@ class Minimum : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1117,6 +1166,7 @@ class Multiply : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1133,6 +1183,7 @@ class Negative : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1149,6 +1200,7 @@ class NotEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1193,6 +1245,7 @@ class Partition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Partition) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -1213,6 +1266,7 @@ class Power : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Power) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1305,6 +1359,9 @@ class Reduce : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; + std::vector> output_shapes( + const std::vector& inputs) override; + void print(std::ostream& os) override { switch (reduce_type_) { case And: @@ -1347,6 +1404,7 @@ class Round : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Round) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1455,6 +1513,7 @@ class Sigmoid : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1471,6 +1530,7 @@ class Sign : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1487,6 +1547,7 @@ class Sin : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1503,6 +1564,7 @@ class Sinh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1547,6 +1609,7 @@ class Softmax : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Softmax) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1563,6 +1626,7 @@ class Sort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sort) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -1604,6 +1668,7 @@ class Square : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Square) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1619,6 +1684,7 @@ class Sqrt : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; void print(std::ostream& os) override { @@ -1644,6 +1710,7 @@ class StopGradient : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1660,6 +1727,7 @@ class Subtract : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1676,6 +1744,7 @@ class Tan : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1692,6 +1761,7 @@ class Tanh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index a1b3461ab..6c7959426 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -18,7 +18,9 @@ std::vector vmap_replace( // idea. std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id); + size_t fun_id, + bool shapeless = false, + std::vector constants = {}); // Erase cached compile functions void compile_erase(size_t fun_id); diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index db07ce190..dfd435cfd 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import math +from functools import partial from typing import Any import mlx.core as mx @@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module def _make_activation_module(f): def decorator(klass): - klass.__doc__ = f.__doc__ - klass.__call__ = lambda self, x: f(x) + klass.__call__ = lambda _, x: f(x) return klass return decorator +@partial(mx.compile, shapeless=True) def sigmoid(x): r"""Applies the element-wise function: @@ -25,6 +26,7 @@ def sigmoid(x): return mx.sigmoid(x) +@partial(mx.compile, shapeless=True) def relu(x): r"""Applies the Rectified Linear Unit. @@ -33,6 +35,7 @@ def relu(x): return mx.maximum(x, 0) +@partial(mx.compile, shapeless=True) def leaky_relu(x, negative_slope=0.01): r"""Applies the Leaky Rectified Linear Unit. @@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01): return mx.maximum(negative_slope * x, x) +@partial(mx.compile, shapeless=True) def log_softmax(x, axis=-1): r"""Applies the Log Softmax function. @@ -49,6 +53,7 @@ def log_softmax(x, axis=-1): return x - mx.logsumexp(x, axis=axis, keepdims=True) +@partial(mx.compile, shapeless=True) def elu(x, alpha=1.0): r"""Applies the Exponential Linear Unit. @@ -57,6 +62,7 @@ def elu(x, alpha=1.0): return mx.where(x > 0, x, alpha * (mx.exp(x) - 1)) +@partial(mx.compile, shapeless=True) def relu6(x): r"""Applies the Rectified Linear Unit 6. @@ -65,6 +71,7 @@ def relu6(x): return mx.minimum(mx.maximum(x, 0), 6.0) +@partial(mx.compile, shapeless=True) def softmax(x, axis=-1): r"""Applies the Softmax function. @@ -73,6 +80,7 @@ def softmax(x, axis=-1): return mx.softmax(x, axis=axis) +@partial(mx.compile, shapeless=True) def softplus(x): r"""Applies the Softplus function. @@ -81,6 +89,7 @@ def softplus(x): return mx.logaddexp(x, 0) +@partial(mx.compile, shapeless=True) def softsign(x): r"""Applies the Softsign function. @@ -89,6 +98,7 @@ def softsign(x): return mx.divide(x, 1 + mx.abs(x)) +@partial(mx.compile, shapeless=True) def softshrink(x, lambd: float = 0.5): r"""Applies the Softshrink activation function. @@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5): return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0) +@partial(mx.compile, shapeless=True) def celu(x, alpha=1.0): r"""Applies the Continuously Differentiable Exponential Linear Unit. @@ -111,6 +122,7 @@ def celu(x, alpha=1.0): return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1) +@partial(mx.compile, shapeless=True) def silu(x): r"""Applies the Sigmoid Linear Unit. Also known as Swish. @@ -120,6 +132,7 @@ def silu(x): return x * mx.sigmoid(x) +@partial(mx.compile, shapeless=True) def log_sigmoid(x): r"""Applies the Log Sigmoid function. @@ -128,6 +141,7 @@ def log_sigmoid(x): return -softplus(-x) +@partial(mx.compile, shapeless=True) def gelu(x): r"""Applies the Gaussian Error Linear Units function. @@ -142,6 +156,7 @@ def gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 +@partial(mx.compile, shapeless=True) def gelu_approx(x): r"""An approximation to Gaussian Error Linear Unit. @@ -159,6 +174,7 @@ def gelu_approx(x): return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square())) +@partial(mx.compile, shapeless=True) def gelu_fast_approx(x): r"""A fast approximation to Gaussian Error Linear Unit. @@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array: return a * mx.sigmoid(b) -class GLU(Module): - r"""Applies the gated linear unit function. - - This function splits the ``axis`` dimension of the input into two halves - (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. - - .. math:: - textrm{GLU}(x) = a * \sigma(b) - - Args: - axis (int): The dimension to split along. Default: ``-1`` - """ - - def __init__(self, axis: int = -1): - super().__init__() - self.axis = axis - - def __call__(self, x) -> Any: - return glu(x=x, axis=self.axis) - - +@partial(mx.compile, shapeless=True) def step(x: mx.array, threshold: float = 0.0): r"""Applies the Step Activation Function. @@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0): return mx.where(x > threshold, 1, 0) +@partial(mx.compile, shapeless=True) def selu(x): r"""Applies the Scaled Exponential Linear Unit. @@ -248,6 +245,7 @@ def selu(x): return elu(x, 1.67326) * 1.0507 +@partial(mx.compile, shapeless=True) def prelu(x: mx.array, alpha: mx.array) -> mx.array: r"""Applies the element-wise parametric ReLU. @@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array: return mx.maximum(0, x) + alpha * mx.minimum(0, x) +@partial(mx.compile, shapeless=True) def mish(x: mx.array) -> mx.array: r"""Applies the Mish function, element-wise. Mish: A Self Regularized Non-Monotonic Neural Activation Function. @@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array: return x * mx.tanh(softplus(x)) +@partial(mx.compile, shapeless=True) def hardswish(x): r"""Applies the hardswish function, element-wise. @@ -282,6 +282,35 @@ def hardswish(x): return x * mx.minimum(max_x_3, 6) / 6 +def tanh(x): + """Applies the hyperbolic tangent function. + + Simply ``mx.tanh(x)``. + """ + return mx.tanh(x) + + +class GLU(Module): + r"""Applies the gated linear unit function. + + This function splits the ``axis`` dimension of the input into two halves + (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. + + .. math:: + textrm{GLU}(x) = a * \sigma(b) + + Args: + axis (int): The dimension to split along. Default: ``-1`` + """ + + def __init__(self, axis: int = -1): + super().__init__() + self.axis = axis + + def __call__(self, x) -> Any: + return glu(x=x, axis=self.axis) + + @_make_activation_module(mx.sigmoid) class Sigmoid(Module): r"""Applies the sigmoid function, element-wise. @@ -500,14 +529,6 @@ class GELU(Module): return self._act(x) -def tanh(x): - """Applies the hyperbolic tangent function. - - Simply ``mx.tanh(x)``. - """ - return mx.tanh(x) - - @_make_activation_module(tanh) class Tanh(Module): r"""Applies the hyperbolic tangent function. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index f081fdedd..cda1d6316 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -555,13 +555,19 @@ struct PyCompiledFun { size_t fun_id; py::object captured_inputs; py::object captured_outputs; + bool shapeless; size_t num_outputs{0}; - PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs) + PyCompiledFun( + const py::function& fun, + py::object inputs, + py::object outputs, + bool shapeless) : fun(fun), fun_id(reinterpret_cast(fun.ptr())), captured_inputs(inputs), - captured_outputs(outputs) {} + captured_outputs(outputs), + shapeless(shapeless) {} PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete; @@ -571,11 +577,15 @@ struct PyCompiledFun { other.fun_id = 0; captured_inputs = std::move(other.captured_inputs); captured_outputs = std::move(other.captured_outputs); + shapeless = other.shapeless; num_outputs = other.num_outputs; }; - py::object operator()(const py::args& args) { - auto compile_fun = [this, &args](const std::vector& a) { + py::object operator()(const py::args& args, const py::kwargs& kwargs) { + auto inputs = tree_flatten(args, false); + + auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()]( + const std::vector& a) { // Put tracers into captured inputs std::vector flat_in_captures; std::vector trace_captures; @@ -586,8 +596,10 @@ struct PyCompiledFun { tree_fill(captured_inputs, trace_captures); } - auto [outputs, py_outputs] = tree_flatten_with_structure( - std::move(fun(*tree_unflatten(args, a))), false); + auto tree_outputs = + fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args)); + auto [outputs, py_outputs] = + tree_flatten_with_structure(std::move(tree_outputs), false); tree_cache().insert({fun_id, py_outputs}); @@ -607,7 +619,14 @@ struct PyCompiledFun { return outputs; }; - auto inputs = tree_flatten(args, false); + { + auto flat_kwargs = tree_flatten(kwargs, false); + inputs.insert( + inputs.end(), + std::make_move_iterator(flat_kwargs.begin()), + std::make_move_iterator(flat_kwargs.end())); + } + if (!py::isinstance(captured_inputs)) { auto flat_in_captures = tree_flatten(captured_inputs, false); inputs.insert( @@ -616,8 +635,39 @@ struct PyCompiledFun { std::make_move_iterator(flat_in_captures.end())); } + // Collect the compilation constants + std::vector constants; + auto value_hash = [](py::handle o) -> std::optional { + // Consider expanding tuples to their contents including start and end + // ids + if (py::isinstance(o) || py::isinstance(o)) { + auto r = py::hash(o); + return *reinterpret_cast(&r); + } else if (py::isinstance(o)) { + auto r = o.cast(); + return *reinterpret_cast(&r); + } else if (py::isinstance(o)) { + auto r = o.cast(); + return *reinterpret_cast(&r); + } else { + return std::nullopt; + } + }; + for (int i = 0; i < args.size(); i++) { + if (auto h = value_hash(args[i]); h.has_value()) { + constants.push_back(*h); + } + } + for (auto& pair : kwargs) { + if (auto h = value_hash(pair.second); h.has_value()) { + constants.push_back(*value_hash(pair.first)); + constants.push_back(*h); + } + } + // Compile and call - auto outputs = detail::compile(compile_fun, fun_id)(inputs); + auto outputs = + detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); if (!py::isinstance(captured_outputs)) { std::vector captures( std::make_move_iterator(outputs.begin() + num_outputs), @@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) { "compile", [](const py::function& fun, const py::object& inputs, - const py::object& outputs) { - return py::cpp_function(PyCompiledFun{fun, inputs, outputs}); + const py::object& outputs, + bool shapeless) { + return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}); }, "fun"_a, "inputs"_a = std::nullopt, "outputs"_a = std::nullopt, + "shapeless"_a = false, R"pbdoc( compile(fun: function) -> function @@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) { :obj:`list` or a :obj:`dict` containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. Default: ``None`` + shapeless (bool, optional): A function compiled with the ``shapeless`` + option enabled will not be recompiled when the input shape changes. Not all + functions can be compiled with ``shapeless`` enabled. Attempting to compile + such functions with shapeless enabled will throw. Note, changing the number + of dimensions or type of any input will result in a recompilation even with + ``shapeless`` set to ``True``. Default: ``False`` Returns: function: A compiled function which has the same input arguments diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 2e0bb1d7f..e53134482 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -381,6 +381,164 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) + def test_compile_kwargs(self): + + @mx.compile + def fun(x, y, z): + return x + y + z + + x = mx.array(1) + y = mx.array(2) + z = mx.array(3) + out = fun(x, y=y, z=z) + self.assertEqual(out.item(), 6) + + def test_shapeless_compile(self): + y = 1 + + @partial(mx.compile, shapeless=True) + def fun(x): + return x + y + + x = mx.array([1, 2]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) + + # The function is not recompiled, so the change + # to y should not be reflected in the output + y = 2 + x = mx.array([1, 2, 3]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) + + # Type change recompiles + x = mx.array([1.0, 2.0, 3.0]) + self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) + fun(x, y=y, z=z) + + def test_shapeless_compile(self): + y = 1 + + @partial(mx.compile, shapeless=True) + def fun(x): + return x + y + + x = mx.array([1, 2]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) + + # The function is not recompiled, so the change + # to y should not be reflected in the output + y = 2 + x = mx.array([1, 2, 3]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) + + # Type change recompiles + x = mx.array([1.0, 2.0, 3.0]) + self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) + + # Dim change recompiles + x = mx.array([[1, 2, 3]]) + self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]]))) + + def test_shapeless_compile_with_broadcasts(self): + x = mx.ones((2, 2)) + y = mx.array([2, 2]) + + def fun(x, y): + return x * y + + cfun = mx.compile(fun, shapeless=True) + self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y))) + self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x))) + y = mx.array([[3]]) + self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y))) + self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x))) + + def test_shapeless_compile_with_reduction(self): + # Test shapeless compile with a reduction + z = 1 + + @partial(mx.compile, shapeless=True) + def fun(x, y): + return x + y.sum(0, keepdims=True) + z + + x = mx.ones((2, 2), mx.int32) + y = mx.ones((2, 2), mx.int32) + self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4))) + x = mx.ones((3, 3), mx.int32) + y = mx.ones((3, 3), mx.int32) + z = 2 + self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5))) + + x1 = mx.array([[1, 2], [3, 4], [5, 6]]) + x2 = mx.array([[1, 2]]) + + def fun(x): + return x * x.sum(-1, keepdims=True) + + cfun = mx.compile(fun, shapeless=True) + mx.eval(cfun(x1)) + self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) + + def test_compile_with_constant(self): + + # Test float + @partial(mx.compile) + def fun(x, y): + return x + y + + z = fun(mx.array(1.0), 1.0) + self.assertEqual(z.item(), 2.0) + + z = fun(mx.array(1.0), 2.0) + self.assertEqual(z.item(), 3.0) + + z = fun(mx.array(1.0), y=1.0) + self.assertEqual(z.item(), 2.0) + + z = fun(mx.array(1.0), y=3.0) + self.assertEqual(z.item(), 4.0) + + # Test tuple + @partial(mx.compile) + def fun(x, y=(1, 2)): + return x + y[0] + y[1] + + z = fun(mx.array(1)) + self.assertEqual(z.item(), 4) + + z = fun(mx.array(1), (2, 2)) + self.assertEqual(z.item(), 5) + + z = fun(mx.array(1), (2, 1)) + self.assertEqual(z.item(), 4) + + # Test bool + @partial(mx.compile) + def fun(x, y): + if y: + return x + 1 + else: + return x + 2 + + z = fun(mx.array(1), True) + self.assertEqual(z.item(), 2) + + z = fun(mx.array(1), False) + self.assertEqual(z.item(), 3) + + # Test string + @partial(mx.compile) + def fun(x, y): + if y == "one": + return x + 1 + else: + return x + 2 + + z = fun(mx.array(1), "one") + self.assertEqual(z.item(), 2) + + z = fun(mx.array(1), "two") + self.assertEqual(z.item(), 3) + if __name__ == "__main__": unittest.main() diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 8ad67a1ed..569ab0913 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -624,31 +624,23 @@ TEST_CASE("test transform compiled function") { CHECK(!outs[0].inputs()[1].has_primitive()); } -TEST_CASE("test metal fusion kernel reuse") { - if (default_device() != Device::gpu) { - return; - } - +TEST_CASE("test fusion kernel reuse") { auto cfun = compile(gelu_1); auto x = array({2.0f, -2.0f}); auto y = cfun({x})[0]; auto p = std::dynamic_pointer_cast(y.primitive_ptr()); eval(y); - std::string lib_name = p->metal_lib_name(); - std::string lib_source = p->metal_lib_source(); + std::string lib_name = p->lib_name(); CHECK(!lib_name.empty()); - CHECK(!lib_source.empty()); x = astype(reshape(arange(10), {2, 5}), float32); auto z = cfun({x})[0]; auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); eval(z); - std::string lib_name_z = pz->metal_lib_name(); - std::string lib_source_z = pz->metal_lib_source(); + std::string lib_name_z = pz->lib_name(); CHECK(!lib_name_z.empty()); - CHECK(lib_source_z.empty()); CHECK_EQ(lib_name, lib_name_z); } @@ -657,29 +649,57 @@ auto add3(const std::vector& xs) { return std::vector{xs[0] + xs[0] + xs[0]}; } -TEST_CASE("test metal fusion types") { - if (default_device() != Device::gpu) { - return; - } - +TEST_CASE("test fusion types") { auto cfun = compile(add3); auto x = array({2.0f, -2.0f}); auto y = cfun({x})[0]; auto p = std::dynamic_pointer_cast(y.primitive_ptr()); eval(y); - std::string lib_name = p->metal_lib_name(); - std::string lib_source = p->metal_lib_source(); + std::string lib_name = p->lib_name(); CHECK(!lib_name.empty()); - CHECK(!lib_source.empty()); x = array({2, -2}, int32); auto z = cfun({x})[0]; auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); eval(z); - std::string lib_name_z = pz->metal_lib_name(); - std::string lib_source_z = pz->metal_lib_source(); + std::string lib_name_z = pz->lib_name(); CHECK(!lib_name_z.empty()); - CHECK(!lib_source_z.empty()); +} + +auto compile_shapeless_not_ok(const std::vector& inputs) { + auto x = reshape(inputs[0], {2, 2}); + return std::vector{x}; +} + +auto compile_shapeless_ok(const std::vector& inputs) { + auto x = inputs[0] + array({2}); + return std::vector{x}; +} + +TEST_CASE("test shapeless compile") { + { + auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true); + CHECK_THROWS(cfun({array({1, 2, 3, 4})})); + } + + { + auto cfun = compile(compile_shapeless_ok, /* shapeless */ true); + auto out = cfun({array({1, 2})})[0]; + auto out2 = cfun({array({1, 2, 3, 4})})[0]; + + // Not making a new constant array since no recompile, + // hence the ids should be the same + CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id()); + CHECK(array_equal(out2, array({3, 4, 5, 6})).item()); + + // Recompile since type changes + out2 = cfun({array({1.0, 2.0})})[0]; + CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); + + // Recompile since ndim changes + out2 = cfun({array({1.0, 2.0}, {1, 2})})[0]; + CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); + } }