diff --git a/.circleci/config.yml b/.circleci/config.yml index 9c8cb31a3..6dc7ec4df 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -24,8 +24,8 @@ jobs: type: boolean default: false macos: - xcode: "15.2.0" - resource_class: macos.m1.medium.gen1 + xcode: "16.2.0" + resource_class: m2pro.medium steps: - checkout - run: @@ -89,15 +89,14 @@ jobs: pip install numpy sudo apt-get update sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev - run: name: Install Python package command: | - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF - CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ + CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ python3 setup.py build_ext --inplace - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \ - CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ + CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ python3 setup.py develop - run: @@ -110,6 +109,8 @@ jobs: name: Run Python tests command: | python3 -m unittest discover python/tests -v + mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py + mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py - run: name: Build CPP only command: | @@ -124,10 +125,15 @@ jobs: parameters: xcode_version: type: string - default: "15.2.0" + default: "16.2.0" + macosx_deployment_target: + type: string + default: "" macos: xcode: << parameters.xcode_version >> - resource_class: macos.m1.medium.gen1 + environment: + MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >> + resource_class: m2pro.medium steps: - checkout - run: @@ -149,7 +155,7 @@ jobs: command: | source env/bin/activate DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ - CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ + CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ pip install -e . -v - run: name: Generate package stubs @@ -213,13 +219,18 @@ jobs: default: "3.9" xcode_version: type: string - default: "15.2.0" + default: "16.2.0" build_env: type: string default: "" + macosx_deployment_target: + type: string + default: "" macos: xcode: << parameters.xcode_version >> - resource_class: macos.m1.medium.gen1 + resource_class: m2pro.medium + environment: + MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >> steps: - checkout - run: @@ -240,7 +251,7 @@ jobs: name: Install Python package command: | source env/bin/activate - DEV_RELEASE=1 \ + env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \ CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ pip install . 
-v - run: @@ -335,7 +346,7 @@ workflows: - mac_build_and_test: matrix: parameters: - xcode_version: ["15.0.0", "15.2.0", "16.0.0"] + macosx_deployment_target: ["13.5", "14.0"] - linux_build_and_test - build_documentation @@ -355,8 +366,70 @@ workflows: matrix: parameters: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - xcode_version: ["15.0.0", "15.2.0"] + macosx_deployment_target: ["13.5", "14.0", "15.0"] build_env: ["PYPI_RELEASE=1"] + xcode_version: ["16.2.0", "15.0.0"] + exclude: + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.9" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.10" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.11" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.12" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.13" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.9" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.10" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.11" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.12" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.13" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.9" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.10" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.11" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.12" + build_env: "PYPI_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.13" + build_env: "PYPI_RELEASE=1" - build_documentation: filters: tags: @@ -379,7 +452,7 @@ workflows: requires: [ hold ] matrix: parameters: - xcode_version: ["15.0.0", "15.2.0", "16.0.0"] + macosx_deployment_target: ["13.5", "14.0"] - linux_build_and_test: requires: [ hold ] nightly_build: @@ -392,7 +465,54 @@ workflows: matrix: parameters: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - xcode_version: ["15.0.0", "15.2.0"] + macosx_deployment_target: ["13.5", "14.0", "15.0"] + xcode_version: ["16.2.0", "15.0.0"] + exclude: + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.9" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.10" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.11" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.12" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.13" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.9" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.10" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.11" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + 
python_version: "3.12" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.13" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.9" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.10" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.11" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.12" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.13" weekly_build: when: and: @@ -403,8 +523,70 @@ workflows: matrix: parameters: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - xcode_version: ["15.0.0", "15.2.0", "16.0.0"] + macosx_deployment_target: ["13.5", "14.0", "15.0"] build_env: ["DEV_RELEASE=1"] + xcode_version: ["16.2.0", "15.0.0"] + exclude: + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.9" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.10" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.11" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.12" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "13.5" + xcode_version: "16.2.0" + python_version: "3.13" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.9" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.10" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.11" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.12" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "14.0" + xcode_version: "15.0.0" + python_version: "3.13" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.9" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.10" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.11" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.12" + build_env: "DEV_RELEASE=1" + - macosx_deployment_target: "15.0" + xcode_version: "15.0.0" + python_version: "3.13" + build_env: "DEV_RELEASE=1" linux_test_release: when: and: diff --git a/CMakeLists.txt b/CMakeLists.txt index 672b9810c..e2002fc94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,24 +212,6 @@ else() set(MLX_BUILD_ACCELERATE OFF) endif() -find_package(MPI) -if(MPI_FOUND) - execute_process( - COMMAND zsh "-c" "mpirun --version" - OUTPUT_VARIABLE MPI_VERSION - ERROR_QUIET) - if(${MPI_VERSION} MATCHES ".*Open MPI.*") - target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH}) - elseif(MPI_VERSION STREQUAL "") - set(MPI_FOUND FALSE) - message( - WARNING "MPI found but mpirun is not available. Building without MPI.") - else() - set(MPI_FOUND FALSE) - message(WARNING "MPI which is not OpenMPI found. 
Building without MPI.") - endif() -endif() - message(STATUS "Downloading json") FetchContent_Declare( json diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f1531bb88..fddb2a974 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,26 +5,26 @@ possible. ## Pull Requests -1. Fork and submit pull requests to the repo. +1. Fork and submit pull requests to the repo. 2. If you've added code that should be tested, add tests. 3. If a change is likely to impact efficiency, run some of the benchmarks before and after the change. Examples of benchmarks can be found in `benchmarks/python/`. 4. If you've changed APIs, update the documentation. -5. Every PR should have passing tests and at least one review. +5. Every PR should have passing tests and at least one review. 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. This should install hooks for running `black` and `clang-format` to ensure consistent style for C++ and python code. - + You can also run the formatters manually as follows: - - ``` - clang-format -i file.cpp - ``` - - ``` - black file.py - ``` - + + ```shell + clang-format -i file.cpp + ``` + + ```shell + black file.py + ``` + or run `pre-commit run --all-files` to check all files in the repo. ## Issues diff --git a/benchmarks/python/gather_mm_bench.py b/benchmarks/python/gather_mm_bench.py new file mode 100644 index 000000000..ffeb73487 --- /dev/null +++ b/benchmarks/python/gather_mm_bench.py @@ -0,0 +1,74 @@ +# Copyright © 2025 Apple Inc. + +import mlx.core as mx +from time_utils import time_fn + +N = 1024 +D = 1024 +M = 1024 +E = 32 +I = 4 + + +def gather_sort(x, indices): + N, M = indices.shape + indices = indices.flatten() + order = mx.argsort(indices) + inv_order = mx.argsort(order) + return x.flatten(0, -3)[order // M], indices[order], inv_order + + +def scatter_unsort(x, inv_order, shape=None): + x = x[inv_order] + if shape is not None: + x = mx.unflatten(x, 0, shape) + return x + + +def gather_mm_simulate(x, w, indices): + x, idx, inv_order = gather_sort(x, indices) + for i in range(2): + y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0) + x = y[:, None] + x = scatter_unsort(x, inv_order, indices.shape) + return x + + +def time_gather_mm(): + x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 + w1 = mx.random.normal((E, M, D)) / 1024**0.5 + w2 = mx.random.normal((E, D, M)) / 1024**0.5 + indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) + sorted_indices = mx.sort(indices.flatten()).reshape(N, I) + mx.eval(x, w1, w2, indices, sorted_indices) + + def gather_mm(x, w1, w2, indices, sort): + idx = indices + inv_order = None + if sort: + x, idx, inv_order = gather_sort(x, indices) + x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort) + x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort) + if sort: + x = scatter_unsort(x, inv_order, indices.shape) + return x + + time_fn(gather_mm, x, w1, w2, indices, False) + time_fn(gather_mm, x, w1, w2, sorted_indices, False) + time_fn(gather_mm, x, w1, w2, indices, True) + + x = mx.random.normal((N * I, D)) / 1024**0.5 + w1 = mx.random.normal((M, D)) / 1024**0.5 + w2 = mx.random.normal((D, M)) / 1024**0.5 + mx.eval(x, w1, w2) + + def equivalent_matmul(x, w1, w2): + x = x @ w1.T + x = x @ w2.T + return x + + time_fn(equivalent_matmul, x, w1, w2) + + +if __name__ == "__main__": + time_gather_mm() diff --git a/benchmarks/python/gather_qmm_bench.py 
b/benchmarks/python/gather_qmm_bench.py new file mode 100644 index 000000000..17c06d57d --- /dev/null +++ b/benchmarks/python/gather_qmm_bench.py @@ -0,0 +1,84 @@ +# Copyright © 2025 Apple Inc. + +import mlx.core as mx +from time_utils import time_fn + +N = 1024 +D = 1024 +M = 1024 +E = 32 +I = 4 + + +def gather_sort(x, indices): + N, M = indices.shape + indices = indices.flatten() + order = mx.argsort(indices) + inv_order = mx.argsort(order) + return x.flatten(0, -3)[order // M], indices[order], inv_order + + +def scatter_unsort(x, inv_order, shape=None): + x = x[inv_order] + if shape is not None: + x = mx.unflatten(x, 0, shape) + return x + + +def gather_mm_simulate(x, w, indices): + x, idx, inv_order = gather_sort(x, indices) + for i in range(2): + y = mx.concatenate( + [ + mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True) + for i, j in enumerate(idx.tolist()) + ], + axis=0, + ) + x = y[:, None] + x = scatter_unsort(x, inv_order, indices.shape) + return x + + +def time_gather_qmm(): + x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 + w1 = mx.random.normal((E, M, D)) / 1024**0.5 + w2 = mx.random.normal((E, D, M)) / 1024**0.5 + w1 = mx.quantize(w1) + w2 = mx.quantize(w2) + indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) + sorted_indices = mx.sort(indices.flatten()).reshape(N, I) + mx.eval(x, w1, w2, indices, sorted_indices) + + def gather_mm(x, w1, w2, indices, sort): + idx = indices + inv_order = None + if sort: + x, idx, inv_order = gather_sort(x, indices) + x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort) + x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort) + if sort: + x = scatter_unsort(x, inv_order, indices.shape) + return x + + time_fn(gather_mm, x, w1, w2, indices, False) + time_fn(gather_mm, x, w1, w2, sorted_indices, False) + time_fn(gather_mm, x, w1, w2, indices, True) + + x = mx.random.normal((N * I, D)) / 1024**0.5 + w1 = mx.random.normal((M, D)) / 1024**0.5 + w2 = mx.random.normal((D, M)) / 1024**0.5 + w1 = mx.quantize(w1) + w2 = mx.quantize(w2) + mx.eval(x, w1, w2) + + def equivalent_matmul(x, w1, w2): + x = mx.quantized_matmul(x, *w1, transpose=True) + x = mx.quantized_matmul(x, *w2, transpose=True) + return x + + time_fn(equivalent_matmul, x, w1, w2) + + +if __name__ == "__main__": + time_gather_qmm() diff --git a/docs/Doxyfile b/docs/Doxyfile index 460e6b503..e47712d44 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/* CREATE_SUBDIRS = NO FULL_PATH_NAMES = YES RECURSIVE = YES -GENERATE_HTML = YES +GENERATE_HTML = NO GENERATE_LATEX = NO GENERATE_XML = YES XML_PROGRAMLISTING = YES diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index b8c3a4995..2aef28f99 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -93,9 +93,9 @@ Primitives ^^^^^^^^^^^ A :class:`Primitive` is part of the computation graph of an :class:`array`. It -defines how to create outputs arrays given a input arrays. Further, a +defines how to create output arrays given input arrays. Further, a :class:`Primitive` has methods to run on the CPU or GPU and for function -transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be +transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be more concrete: .. code-block:: C++ @@ -128,7 +128,7 @@ more concrete: /** The vector-Jacobian product. 
*/ std::vector vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; @@ -469,7 +469,7 @@ one we just defined: const std::vector& tangents, const std::vector& argnums) { // Forward mode diff that pushes along the tangents - // The jvp transform on the primitive can built with ops + // The jvp transform on the primitive can be built with ops // that are scheduled on the same stream as the primitive // If argnums = {0}, we only push along x in which case the @@ -481,7 +481,7 @@ one we just defined: auto scale_arr = array(scale, tangents[0].dtype()); return {multiply(scale_arr, tangents[0], stream())}; } - // If, argnums = {0, 1}, we take contributions from both + // If argnums = {0, 1}, we take contributions from both // which gives us jvp = tangent_x * alpha + tangent_y * beta else { return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; @@ -735,7 +735,7 @@ Let's look at a simple script and its results: print(f"c shape: {c.shape}") print(f"c dtype: {c.dtype}") - print(f"c correct: {mx.all(c == 6.0).item()}") + print(f"c is correct: {mx.all(c == 6.0).item()}") Output: @@ -743,7 +743,7 @@ Output: c shape: [3, 4] c dtype: float32 - c correctness: True + c is correct: True Results ^^^^^^^ diff --git a/docs/src/index.rst b/docs/src/index.rst index 075861e88..51e719572 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -70,6 +70,7 @@ are the CPU and GPU. python/fft python/linalg python/metal + python/memory_management python/nn python/optimizers python/distributed diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index 532bb45c9..7e1c3339d 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -38,6 +38,7 @@ Array array.log10 array.log1p array.log2 + array.logcumsumexp array.logsumexp array.max array.mean diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 769f4bbb1..b01f74117 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -20,5 +20,6 @@ Linear Algebra eigh lu lu_factor + pinv solve solve_triangular diff --git a/docs/src/python/memory_management.rst b/docs/src/python/memory_management.rst new file mode 100644 index 000000000..f708efbfd --- /dev/null +++ b/docs/src/python/memory_management.rst @@ -0,0 +1,16 @@ +Memory Management +================= + +.. currentmodule:: mlx.core + +.. 
autosummary:: + :toctree: _autosummary + + get_active_memory + get_peak_memory + reset_peak_memory + get_cache_memory + set_memory_limit + set_cache_limit + set_wired_limit + clear_cache diff --git a/docs/src/python/metal.rst b/docs/src/python/metal.rst index 4d6fb91d9..83a363c3b 100644 --- a/docs/src/python/metal.rst +++ b/docs/src/python/metal.rst @@ -8,13 +8,5 @@ Metal is_available device_info - get_active_memory - get_peak_memory - reset_peak_memory - get_cache_memory - set_memory_limit - set_cache_limit - set_wired_limit - clear_cache start_capture stop_capture diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index c0d098b21..55fc1f534 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -36,10 +36,12 @@ Operations bitwise_or bitwise_xor block_masked_mm + broadcast_arrays broadcast_to ceil clip concatenate + contiguous conj conjugate convolve @@ -101,6 +103,7 @@ Operations log10 log1p logaddexp + logcumsumexp logical_not logical_and logical_or diff --git a/docs/src/python/optimizers/common_optimizers.rst b/docs/src/python/optimizers/common_optimizers.rst index 41b3fba03..86f800135 100644 --- a/docs/src/python/optimizers/common_optimizers.rst +++ b/docs/src/python/optimizers/common_optimizers.rst @@ -18,3 +18,4 @@ Common Optimizers AdamW Adamax Lion + MultiOptimizer diff --git a/docs/src/python/transforms.rst b/docs/src/python/transforms.rst index fbdfd4f08..23f86720b 100644 --- a/docs/src/python/transforms.rst +++ b/docs/src/python/transforms.rst @@ -9,6 +9,7 @@ Transforms :toctree: _autosummary eval + async_eval compile custom_function disable_compile diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 76fe389d4..abf46a7d5 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp index 2d97a6db3..fbfca7551 100644 --- a/mlx/allocator.cpp +++ b/mlx/allocator.cpp @@ -4,7 +4,6 @@ #include #include "mlx/allocator.h" -#include "mlx/scheduler.h" namespace mlx::core::allocator { @@ -22,23 +21,4 @@ void free(Buffer buffer) { allocator().free(buffer); } -Buffer CommonAllocator::malloc(size_t size) { - void* ptr = std::malloc(size + sizeof(size_t)); - if (ptr != nullptr) { - *static_cast(ptr) = size; - } - return Buffer{ptr}; -} - -void CommonAllocator::free(Buffer buffer) { - std::free(buffer.ptr()); -} - -size_t CommonAllocator::size(Buffer buffer) const { - if (buffer.ptr() == nullptr) { - return 0; - } - return *static_cast(buffer.ptr()); -} - } // namespace mlx::core::allocator diff --git a/mlx/allocator.h b/mlx/allocator.h index d4e3e1d6e..362f4f08a 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -49,16 +49,4 @@ class Allocator { Allocator& allocator(); -class CommonAllocator : public Allocator { - /** A general CPU allocator. 
*/ - public: - virtual Buffer malloc(size_t size) override; - virtual void free(Buffer buffer) override; - virtual size_t size(Buffer buffer) const override; - - private: - CommonAllocator() = default; - friend Allocator& allocator(); -}; - } // namespace mlx::core::allocator diff --git a/mlx/array.h b/mlx/array.h index d690dcd97..66a4702a6 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -339,11 +339,11 @@ class array { return allocator::allocator().size(buffer()); } - // Return a copy of the shared pointer - // to the array::Data struct - std::shared_ptr data_shared_ptr() const { + // Return the shared pointer to the array::Data struct + const std::shared_ptr& data_shared_ptr() const { return array_desc_->data; } + // Return a raw pointer to the arrays data template T* data() { diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 82e6eef84..6c4e25067 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,6 +1,7 @@ target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp diff --git a/mlx/backend/common/broadcasting.cpp b/mlx/backend/common/broadcasting.cpp new file mode 100644 index 000000000..49bc75b8f --- /dev/null +++ b/mlx/backend/common/broadcasting.cpp @@ -0,0 +1,24 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +void broadcast(const array& in, array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + Strides strides(out.ndim(), 0); + int diff = out.ndim() - in.ndim(); + for (int i = in.ndim() - 1; i >= 0; --i) { + strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i]; + } + auto flags = in.flags(); + if (out.size() > in.size()) { + flags.row_contiguous = flags.col_contiguous = false; + } + out.copy_shared_buffer(in, strides, flags, in.data_size()); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/broadcasting.h b/mlx/backend/common/broadcasting.h new file mode 100644 index 000000000..29651e909 --- /dev/null +++ b/mlx/backend/common/broadcasting.h @@ -0,0 +1,11 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void broadcast(const array& in, array& out); + +} // namespace mlx::core diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 57813e062..2cda88a31 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" @@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector& inputs, array& out) { return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); } -void broadcast(const array& in, array& out) { - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - Strides strides(out.ndim(), 0); - int diff = out.ndim() - in.ndim(); - for (int i = in.ndim() - 1; i >= 0; --i) { - strides[i + diff] = (in.shape()[i] == 1) ? 
0 : in.strides()[i]; - } - auto flags = in.flags(); - if (out.size() > in.size()) { - flags.row_contiguous = flags.col_contiguous = false; - } - out.copy_shared_buffer(in, strides, flags, in.data_size()); -} - void Broadcast::eval(const std::vector& inputs, array& out) { broadcast(inputs[0], out); } diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index e36e0567a..152f33b17 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -58,6 +58,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp @@ -73,8 +74,8 @@ target_sources( if(MLX_BUILD_ACCELERATE) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) else() - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp) endif() if(IOS) diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index 1afa027a8..a3edf8f49 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -46,8 +46,15 @@ void AllReduce::eval_cpu( case Sum: distributed::detail::all_sum(group(), in, outputs[0], stream()); break; + case Max: + distributed::detail::all_max(group(), in, outputs[0], stream()); + break; + case Min: + distributed::detail::all_min(group(), in, outputs[0], stream()); + break; default: - throw std::runtime_error("Only all reduce sum is supported for now"); + throw std::runtime_error( + "Only all reduce sum, min and max are supported for now"); } } diff --git a/mlx/backend/cpu/gemms/no_bf16.cpp b/mlx/backend/cpu/gemms/no_bf16.cpp deleted file mode 100644 index 157c07f46..000000000 --- a/mlx/backend/cpu/gemms/no_bf16.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/cpu/gemm.h" - -namespace mlx::core { - -template <> -void matmul( - const bfloat16_t*, - const bfloat16_t*, - bfloat16_t*, - bool, - bool, - size_t, - size_t, - size_t, - float, - float, - size_t, - const Shape&, - const Strides&, - const Shape&, - const Strides&) { - throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported."); -} - -} // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/no_fp16.cpp b/mlx/backend/cpu/gemms/no_fp16.cpp deleted file mode 100644 index 3f3f41cc5..000000000 --- a/mlx/backend/cpu/gemms/no_fp16.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/cpu/gemm.h" - -namespace mlx::core { - -template <> -void matmul( - const float16_t*, - const float16_t*, - float16_t*, - bool, - bool, - size_t, - size_t, - size_t, - float, - float, - size_t, - const Shape&, - const Strides&, - const Shape&, - const Strides&) { - throw std::runtime_error("[Matmul::eval_cpu] float16 not supported."); -} - -} // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/simd_bf16.cpp b/mlx/backend/cpu/gemms/simd_bf16.cpp new file mode 100644 index 000000000..58f5964b6 --- /dev/null +++ b/mlx/backend/cpu/gemms/simd_bf16.cpp @@ -0,0 +1,45 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/gemm.h" +#include "mlx/backend/cpu/gemms/simd_gemm.h" + +namespace mlx::core { + +template <> +void matmul( + const bfloat16_t* a, + const bfloat16_t* b, + bfloat16_t* out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + size_t ldc, + float alpha, + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides) { + auto ndim = a_shape.size(); + size_t M = a_shape[ndim - 2]; + size_t N = b_shape[ndim - 1]; + size_t K = a_shape[ndim - 1]; + for (int i = 0; i < batch_size; ++i) { + simd_gemm( + a + elem_to_loc(M * K * i, a_shape, a_strides), + b + elem_to_loc(K * N * i, b_shape, b_strides), + out + M * N * i, + a_transposed, + b_transposed, + M, + N, + K, + alpha, + beta); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/simd_fp16.cpp b/mlx/backend/cpu/gemms/simd_fp16.cpp new file mode 100644 index 000000000..93467da86 --- /dev/null +++ b/mlx/backend/cpu/gemms/simd_fp16.cpp @@ -0,0 +1,45 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/gemm.h" +#include "mlx/backend/cpu/gemms/simd_gemm.h" + +namespace mlx::core { + +template <> +void matmul( + const float16_t* a, + const float16_t* b, + float16_t* out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + size_t ldc, + float alpha, + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides) { + auto ndim = a_shape.size(); + size_t M = a_shape[ndim - 2]; + size_t N = b_shape[ndim - 1]; + size_t K = a_shape[ndim - 1]; + for (int i = 0; i < batch_size; ++i) { + simd_gemm( + a + elem_to_loc(M * K * i, a_shape, a_strides), + b + elem_to_loc(K * N * i, b_shape, b_strides), + out + M * N * i, + a_transposed, + b_transposed, + M, + N, + K, + alpha, + beta); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/simd_gemm.h b/mlx/backend/cpu/gemms/simd_gemm.h new file mode 100644 index 000000000..a23c7dea3 --- /dev/null +++ b/mlx/backend/cpu/gemms/simd_gemm.h @@ -0,0 +1,139 @@ +// Copyright © 2025 Apple Inc. 
+#pragma once + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core { + +inline int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +template +void load_block( + const T* in, + AccT* out, + int M, + int N, + int i, + int j, + bool transpose) { + if (transpose) { + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + out[jj * block_size + ii] = + in[(i * block_size + ii) * N + j * block_size + jj]; + } + } + } else { + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + out[ii * block_size + jj] = + in[(i * block_size + ii) * N + j * block_size + jj]; + } + } + } +} + +template +void simd_gemm( + const T* a, + const T* b, + T* c, + bool a_trans, + bool b_trans, + int M, + int N, + int K, + float alpha, + float beta) { + constexpr int block_size = 16; + constexpr int simd_size = simd::max_size; + static_assert( + (block_size % simd_size) == 0, + "Block size must be divisible by SIMD size"); + + int last_k_block_size = K - block_size * (K / block_size); + int last_k_simd_block = (last_k_block_size / simd_size) * simd_size; + for (int i = 0; i < ceildiv(M, block_size); i++) { + for (int j = 0; j < ceildiv(N, block_size); j++) { + AccT c_block[block_size * block_size] = {0.0}; + AccT a_block[block_size * block_size]; + AccT b_block[block_size * block_size]; + + int k = 0; + for (; k < K / block_size; k++) { + // Load a and b blocks + if (a_trans) { + load_block(a, a_block, K, M, k, i, true); + } else { + load_block(a, a_block, M, K, i, k, false); + } + if (b_trans) { + load_block(b, b_block, N, K, j, k, false); + } else { + load_block(b, b_block, K, N, k, j, true); + } + + // Multiply and accumulate + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + for (int kk = 0; kk < block_size; kk += simd_size) { + auto av = + simd::load(a_block + ii * block_size + kk); + auto bv = + simd::load(b_block + jj * block_size + kk); + c_block[ii * block_size + jj] += simd::sum(av * bv); + } + } + } + } + if (last_k_block_size) { + // Load a and b blocks + if (a_trans) { + load_block(a, a_block, K, M, k, i, true); + } else { + load_block(a, a_block, M, K, i, k, false); + } + if (b_trans) { + load_block(b, b_block, N, K, j, k, false); + } else { + load_block(b, b_block, K, N, k, j, true); + } + + // Multiply and accumulate + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + int kk = 0; + for (; kk < last_k_simd_block; kk += simd_size) { + auto av = + simd::load(a_block + ii * block_size + kk); + auto bv = + simd::load(b_block + jj * block_size + kk); + c_block[ii * block_size + jj] += simd::sum(av * bv); + } + for (; kk < last_k_block_size; ++kk) { + c_block[ii * block_size + jj] += + a_block[ii * block_size + kk] * b_block[jj * block_size + kk]; + } + } + } + } + + // Store + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + auto c_idx = (i * block_size + ii) * N + j * block_size + jj; + if (beta != 0) { + c[c_idx] = static_cast( + alpha * c_block[ii * block_size + jj] + beta * c[c_idx]); + } else { + c[c_idx] = static_cast(alpha * c_block[ii * block_size + jj]); + } + } + } + } + } +} + +} // namespace mlx::core diff --git 
a/mlx/backend/cpu/logsumexp.cpp b/mlx/backend/cpu/logsumexp.cpp new file mode 100644 index 000000000..56f0dab9f --- /dev/null +++ b/mlx/backend/cpu/logsumexp.cpp @@ -0,0 +1,140 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/simd/simd.h" +#include "mlx/primitives.h" +#include "mlx/types/limits.h" + +namespace mlx::core { + +namespace { + +using namespace mlx::core::simd; + +template +void logsumexp(const array& in, array& out, Stream stream) { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(in); + encoder.set_output_array(out); + + const T* in_ptr = in.data(); + T* out_ptr = out.data(); + + int M = in.shape().back(); + int L = in.data_size() / M; + + encoder.dispatch([in_ptr, out_ptr, M, L]() mutable { + constexpr int N = std::min(max_size, max_size); + + const T* current_in_ptr; + + for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) { + // Find the maximum + current_in_ptr = in_ptr; + Simd vmaximum(-numeric_limits::infinity()); + size_t s = M; + while (s >= N) { + Simd vals = load(current_in_ptr); + vmaximum = maximum(vals, vmaximum); + current_in_ptr += N; + s -= N; + } + + AccT maximum = max(vmaximum); + while (s-- > 0) { + maximum = std::max(maximum, static_cast(*current_in_ptr)); + current_in_ptr++; + } + + // Compute the normalizer and the exponentials + Simd vnormalizer(0.0); + current_in_ptr = in_ptr; + s = M; + while (s >= N) { + Simd vexp = load(current_in_ptr); + vexp = exp(vexp - maximum); + vnormalizer = vnormalizer + vexp; + current_in_ptr += N; + s -= N; + } + AccT normalizer = sum(vnormalizer); + while (s-- > 0) { + AccT _exp = std::exp(*current_in_ptr - maximum); + normalizer += _exp; + current_in_ptr++; + } + // Normalize + *out_ptr = std::isinf(maximum) + ? 
static_cast(maximum) + : static_cast(std::log(normalizer) + maximum); + } + }); +} + +} // namespace + +void LogSumExp::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + // Make sure that the last dimension is contiguous + auto s = stream(); + auto& encoder = cpu::get_command_encoder(s); + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy(x, x_copy, CopyType::General, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto in = ensure_contiguous(inputs[0]); + if (in.flags().row_contiguous) { + out.set_data(allocator::malloc(out.nbytes())); + } else { + auto n = in.shape(-1); + auto flags = in.flags(); + auto strides = in.strides(); + for (auto& s : strides) { + s /= n; + } + bool col_contig = strides[0] == 1; + for (int i = 1; col_contig && i < strides.size(); ++i) { + col_contig &= + (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); + } + flags.col_contiguous = col_contig; + out.set_data( + allocator::malloc(in.nbytes() / n), + in.data_size() / n, + std::move(strides), + flags); + } + + switch (in.dtype()) { + case float32: + logsumexp(in, out, stream()); + break; + case float16: + logsumexp(in, out, stream()); + break; + case bfloat16: + logsumexp(in, out, stream()); + break; + case float64: + logsumexp(in, out, stream()); + break; + default: + throw std::runtime_error( + "[logsumexp] only supports floating point types"); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index 1dfae8524..2a612a2d9 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector& inputs, array& out) { void Contiguous::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous)) { + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { out.copy_shared_buffer(in); } else { copy(in, out, CopyType::General, stream()); diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 1a44ebd39..199dbab35 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -3,6 +3,7 @@ #include #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" @@ -226,6 +227,16 @@ void scan_dispatch( scan_op(in, out, axis, reverse, inclusive, op, init); break; } + case Scan::LogAddExp: { + auto op = [](U a, T b) { + return detail::LogAddExp{}(a, static_cast(b)); + }; + auto init = (issubdtype(in.dtype(), floating)) + ? 
static_cast(-std::numeric_limits::infinity()) + : std::numeric_limits::min(); + scan_op(in, out, axis, reverse, inclusive, op, init); + break; + } } } diff --git a/mlx/backend/cpu/simd/accelerate_fp16_simd.h b/mlx/backend/cpu/simd/accelerate_fp16_simd.h index 1f21d2e18..950544895 100644 --- a/mlx/backend/cpu/simd/accelerate_fp16_simd.h +++ b/mlx/backend/cpu/simd/accelerate_fp16_simd.h @@ -17,7 +17,7 @@ struct ScalarT { #endif template <> -static constexpr int max_size = N; +inline constexpr int max_size = N; #define SIMD_FP16_DEFAULT_UNARY(op) \ template <> \ diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index a14d99103..37b3cdbd8 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -83,25 +83,25 @@ struct Simd { // Values chosen based on benchmarks on M3 Max // TODO: consider choosing these more optimally template <> -static constexpr int max_size = 16; +inline constexpr int max_size = 16; template <> -static constexpr int max_size = 16; +inline constexpr int max_size = 16; template <> -static constexpr int max_size = 8; +inline constexpr int max_size = 8; template <> -static constexpr int max_size = 4; +inline constexpr int max_size = 4; template <> -static constexpr int max_size = 16; +inline constexpr int max_size = 16; template <> -static constexpr int max_size = 16; +inline constexpr int max_size = 16; template <> -static constexpr int max_size = 8; +inline constexpr int max_size = 8; template <> -static constexpr int max_size = 4; +inline constexpr int max_size = 4; template <> -static constexpr int max_size = 8; +inline constexpr int max_size = 8; template <> -static constexpr int max_size = 4; +inline constexpr int max_size = 4; #define SIMD_DEFAULT_UNARY(name, op) \ template \ diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index bc416fc22..7e82a4d56 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh) DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) -DEFAULT_UNARY(log2, std::log2) DEFAULT_UNARY(log10, std::log10) DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(sinh, std::sinh) @@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) +template +Simd log2(Simd in) { + if constexpr (is_complex) { + auto out = std::log(in.value); + auto scale = decltype(out.real())(M_LN2); + return Simd{T{out.real() / scale, out.imag() / scale}}; + } else { + return Simd{std::log2(in.value)}; + } +} + template Simd operator~(Simd in) { return ~in.value; diff --git a/mlx/backend/cpu/softmax.cpp b/mlx/backend/cpu/softmax.cpp index 78e4a3e68..41d14f556 100644 --- a/mlx/backend/cpu/softmax.cpp +++ b/mlx/backend/cpu/softmax.cpp @@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous auto set_output = [s = stream(), &out](const array& x) { - bool no_copy = x.strides()[x.ndim() - 1] == 1; - if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); - } - if (no_copy) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { @@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { auto in = set_output(inputs[0]); switch (in.dtype()) { - case bool_: - case uint8: - case uint16: - case 
uint32: - case uint64: - case int8: - case int16: - case int32: - case int64: - throw std::runtime_error( - "Softmax is defined only for floating point types"); - break; case float32: softmax(in, out, stream()); break; @@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { case float64: softmax(in, out, stream()); break; - case complex64: - throw std::invalid_argument( - "[Softmax] Not yet implemented for complex64"); + default: + throw std::runtime_error( + "[softmax] Only defined for floating point types."); break; } } diff --git a/mlx/backend/cpu/unary.cpp b/mlx/backend/cpu/unary.cpp index 89d1cafb3..eafe98866 100644 --- a/mlx/backend/cpu/unary.cpp +++ b/mlx/backend/cpu/unary.cpp @@ -1,5 +1,8 @@ // Copyright © 2024 Apple Inc. +// Required for using M_LN2 in MSVC. +#define _USE_MATH_DEFINES + #include #include "mlx/backend/cpu/unary.h" diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h index aab1e0de2..633230658 100644 --- a/mlx/backend/cpu/unary_ops.h +++ b/mlx/backend/cpu/unary_ops.h @@ -86,13 +86,14 @@ struct Sign { template Simd operator()(Simd x) { auto z = Simd{0}; + auto o = Simd{1}; + auto m = Simd{-1}; if constexpr (std::is_unsigned_v) { - return x != z; + return simd::select(x == z, z, o); } else if constexpr (std::is_same_v) { return simd::select(x == z, x, Simd(x / simd::abs(x))); } else { - return simd::select( - x < z, Simd{-1}, simd::select(x > z, Simd{1}, z)); + return simd::select(x < z, m, simd::select(x > z, o, z)); } } SINGLE() diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index e49201277..332c560f8 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -47,6 +47,7 @@ if(MLX_METAL_JIT) make_jit_source(binary) make_jit_source(binary_two) make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h) + make_jit_source(logsumexp) make_jit_source(ternary) make_jit_source(softmax) make_jit_source(scan) @@ -60,6 +61,7 @@ if(MLX_METAL_JIT) kernels/steel/gemm/transforms.h) make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) + make_jit_source(steel/gemm/kernels/steel_gemm_gather) make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source( steel/conv/conv @@ -95,6 +97,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index d7b84a165..0a69dd261 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/resident.h" +#include "mlx/memory.h" #include #include @@ -32,8 +33,11 @@ namespace metal { namespace { -BufferCache::BufferCache(MTL::Device* device) - : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {} +BufferCache::BufferCache(ResidencySet& residency_set) + : head_(nullptr), + tail_(nullptr), + pool_size_(0), + residency_set_(residency_set) {} BufferCache::~BufferCache() { auto pool = metal::new_scoped_memory_pool(); @@ -44,6 +48,9 @@ int BufferCache::clear() { int n_release = 0; for (auto& [size, holder] : buffer_pool_) { if (holder->buf) { + if (!holder->buf->heap()) 
{ + residency_set_.erase(holder->buf); + } holder->buf->release(); n_release++; } @@ -101,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) { while (tail_ && (total_bytes_freed < min_bytes_to_free)) { if (tail_->buf) { total_bytes_freed += tail_->buf->length(); + if (!tail_->buf->heap()) { + residency_set_.erase(tail_->buf); + } tail_->buf->release(); tail_->buf = nullptr; n_release++; @@ -155,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), residency_set_(device_), - buffer_cache_(device_) { + buffer_cache_(residency_set_) { auto pool = metal::new_scoped_memory_pool(); auto memsize = std::get(device_info().at("memory_size")); auto max_rec_size = @@ -262,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size) { if (!buf) { buf = device_->newBuffer(size, resource_options); } + if (!buf) { + return Buffer{nullptr}; + } lk.lock(); - if (buf) { - num_resources_++; + num_resources_++; + if (!buf->heap()) { + residency_set_.insert(buf); } } @@ -278,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size) { get_cache_memory() - max_pool_size_); } - if (!buf->heap()) { - residency_set_.insert(buf); - } - return Buffer{static_cast(buf)}; } @@ -297,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) { return; } std::unique_lock lk(mutex_); - if (!buf->heap()) { - residency_set_.erase(buf); - } active_memory_ -= buf->length(); if (get_cache_memory() < max_pool_size_) { buffer_cache_.recycle_to_cache(buf); } else { num_resources_--; + if (!buf->heap()) { + residency_set_.erase(buf); + } lk.unlock(); auto pool = metal::new_scoped_memory_pool(); buf->release(); @@ -323,40 +333,40 @@ MetalAllocator& allocator() { return *allocator_; } +} // namespace metal + size_t set_cache_limit(size_t limit) { - return allocator().set_cache_limit(limit); + return metal::allocator().set_cache_limit(limit); } size_t set_memory_limit(size_t limit) { - return allocator().set_memory_limit(limit); + return metal::allocator().set_memory_limit(limit); } size_t get_memory_limit() { - return allocator().get_memory_limit(); + return metal::allocator().get_memory_limit(); } size_t set_wired_limit(size_t limit) { - if (limit > - std::get(device_info().at("max_recommended_working_set_size"))) { + if (limit > std::get(metal::device_info().at( + "max_recommended_working_set_size"))) { throw std::invalid_argument( "[metal::set_wired_limit] Setting a wired limit larger than " "the maximum working set size is not allowed."); } - return allocator().set_wired_limit(limit); + return metal::allocator().set_wired_limit(limit); } size_t get_active_memory() { - return allocator().get_active_memory(); + return metal::allocator().get_active_memory(); } size_t get_peak_memory() { - return allocator().get_peak_memory(); + return metal::allocator().get_peak_memory(); } void reset_peak_memory() { - allocator().reset_peak_memory(); + metal::allocator().reset_peak_memory(); } size_t get_cache_memory() { - return allocator().get_cache_memory(); + return metal::allocator().get_cache_memory(); } void clear_cache() { - return allocator().clear_cache(); + return metal::allocator().clear_cache(); } -} // namespace metal - } // namespace mlx::core diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 8b77ff6c1..227b09e91 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -18,7 +18,7 @@ namespace { class BufferCache { public: - 
BufferCache(MTL::Device* device); + BufferCache(ResidencySet& residency_set); ~BufferCache(); MTL::Buffer* reuse_from_cache(size_t size); @@ -42,13 +42,11 @@ class BufferCache { void add_at_head(BufferHolder* to_add); void remove_from_list(BufferHolder* to_remove); - MTL::Device* device_; - MTL::Heap* heap_{nullptr}; - std::multimap buffer_pool_; BufferHolder* head_; BufferHolder* tail_; size_t pool_size_; + ResidencySet& residency_set_; }; } // namespace diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index c4803a380..9075ea4c5 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -712,6 +712,65 @@ void winograd_conv_2D_gpu( } } +void depthwise_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params) { + std::ostringstream kname; + kname << "depthwise_conv_2d_" << type_to_name(out); + std::string base_name = kname.str(); + + const int N = conv_params.N; + const int ker_h = conv_params.wS[0]; + const int ker_w = conv_params.wS[1]; + const int str_h = conv_params.str[0]; + const int str_w = conv_params.str[1]; + const int tc = 8; + const int tw = 8; + const int th = 4; + const bool do_flip = conv_params.flip; + + metal::MTLFCList func_consts = { + {&ker_h, MTL::DataType::DataTypeInt, 00}, + {&ker_w, MTL::DataType::DataTypeInt, 01}, + {&str_h, MTL::DataType::DataTypeInt, 10}, + {&str_w, MTL::DataType::DataTypeInt, 11}, + {&th, MTL::DataType::DataTypeInt, 100}, + {&tw, MTL::DataType::DataTypeInt, 101}, + {&do_flip, MTL::DataType::DataTypeBool, 200}, + }; + + // clang-format off + kname << "_ker_h_" << ker_h + << "_ker_w_" << ker_w + << "_str_h_" << str_h + << "_str_w_" << str_w + << "_tgp_h_" << th + << "_tgp_w_" << tw + << "_do_flip_" << (do_flip ? 
't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(in, 0); + compute_encoder.set_input_array(wt, 1); + compute_encoder.set_output_array(out, 2); + + compute_encoder.set_bytes(conv_params, 3); + + MTL::Size group_dims = MTL::Size(tc, tw, th); + MTL::Size grid_dims = MTL::Size( + conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + void conv_2D_gpu( const Stream& s, metal::Device& d, @@ -754,11 +813,20 @@ void conv_2D_gpu( bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; - if (groups > 1) { + if (is_idil_one && groups > 1) { const int C_per_group = conv_params.C / groups; const int O_per_group = conv_params.O / groups; - if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && + if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && + conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && + conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && + conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && + conv_params.wt_strides[1] == conv_params.wS[1] && + conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { + return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + if ((C_per_group <= 4 || C_per_group % 16 == 0) && (O_per_group <= 16 || O_per_group % 16 == 0)) { return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); } else { diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 930e570e2..95aeb1cc9 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -55,7 +55,10 @@ std::pair load_library_from_path( } #ifdef SWIFTPM_BUNDLE -MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { +MTL::Library* try_load_bundle( + MTL::Device* device, + NS::URL* url, + const std::string& lib_name) { std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + SWIFTPM_BUNDLE + ".bundle"; auto bundle = NS::Bundle::alloc()->init( @@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { if (bundle != nullptr) { std::string resource_path = std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + - "default.metallib"; - auto [lib, error] = load_library_from_path(device, resource_path.c_str()); + lib_name + ".metallib" auto [lib, error] = + load_library_from_path(device, resource_path.c_str()); if (lib) { return lib; } @@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { } #endif +// Firstly, search for the metallib in the same path as this binary +std::pair load_colocated_library( + MTL::Device* device, + const std::string& lib_name) { + std::string lib_path = get_colocated_mtllib_path(lib_name); + if (lib_path.size() != 0) { + return load_library_from_path(device, lib_path.c_str()); + } + return {nullptr, nullptr}; +} + +std::pair load_swiftpm_library( + MTL::Device* device, + const std::string& lib_name) { +#ifdef SWIFTPM_BUNDLE + MTL::Library* library = + try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name); + if (library != nullptr) { + return {library, nullptr}; + } + auto bundles = NS::Bundle::allBundles(); + for (int i = 0, c = (int)bundles->count(); i < c; i++) { + 
auto bundle = reinterpret_cast(bundles->object(i)); + library = try_load_bundle(device, bundle->resourceURL()); + if (library != nullptr) { + return {library, nullptr}; + } + } +#endif + return {nullptr, nullptr}; +} + +MTL::Library* load_default_library(MTL::Device* device) { + NS::Error *error1, *error2, *error3; + MTL::Library* lib; + // First try the colocated mlx.metallib + std::tie(lib, error1) = load_colocated_library(device, "mlx"); + if (lib) { + return lib; + } + + // Then try default.metallib in a SwiftPM bundle if we have one + std::tie(lib, error2) = load_swiftpm_library(device, "default"); + if (lib) { + return lib; + } + + // Finally try default_mtllib_path + std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path); + if (!lib) { + std::ostringstream msg; + msg << "Failed to load the default metallib. "; + if (error1 != nullptr) { + msg << error1->localizedDescription()->utf8String() << " "; + } + if (error2 != nullptr) { + msg << error2->localizedDescription()->utf8String() << " "; + } + if (error3 != nullptr) { + msg << error3->localizedDescription()->utf8String() << " "; + } + throw std::runtime_error(msg.str()); + } + return lib; +} + MTL::Library* load_library( MTL::Device* device, - const std::string& lib_name = "mlx", - const char* lib_path = default_mtllib_path) { - // Firstly, search for the metallib in the same path as this binary - std::string first_path = get_colocated_mtllib_path(lib_name); - if (first_path.size() != 0) { - auto [lib, error] = load_library_from_path(device, first_path.c_str()); + const std::string& lib_name, + const std::string& lib_path) { + // We have been given a path that ends in metallib so try to load it + if (lib_path.size() > 9 && + std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) { + auto [lib, error] = load_library_from_path(device, lib_path.c_str()); + if (!lib) { + std::ostringstream msg; + msg << "Failed to load the metallib from <" << lib_path << "> with error " + << error->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } + } + + // We have been given a path so try to load from lib_path / lib_name.metallib + if (lib_path.size() > 0) { + std::string full_path = lib_path + "/" + lib_name + ".metallib"; + auto [lib, error] = load_library_from_path(device, full_path.c_str()); + if (!lib) { + std::ostringstream msg; + msg << "Failed to load the metallib from <" << full_path + << "> with error " << error->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } + } + + // Try to load the colocated library + { + auto [lib, error] = load_colocated_library(device, lib_name); if (lib) { return lib; } } -#ifdef SWIFTPM_BUNDLE - // try to load from a swiftpm resource bundle -- scan the available bundles to - // find one that contains the named bundle + // Try to load the library from swiftpm { - MTL::Library* library = - try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL()); - if (library != nullptr) { - return library; - } - auto bundles = NS::Bundle::allBundles(); - for (int i = 0, c = (int)bundles->count(); i < c; i++) { - auto bundle = reinterpret_cast(bundles->object(i)); - library = try_load_bundle(device, bundle->resourceURL()); - if (library != nullptr) { - return library; - } + auto [lib, error] = load_swiftpm_library(device, lib_name); + if (lib) { + return lib; } } -#endif - // Couldn't find it so let's load it from default_mtllib_path - { - auto [lib, error] = load_library_from_path(device, lib_path); - if (!lib) { - std::ostringstream 
msg; - msg << error->localizedDescription()->utf8String() << "\n" - << "Failed to load device library from <" << lib_path << ">" - << " or <" << first_path << ">."; - throw std::runtime_error(msg.str()); - } - return lib; - } + std::ostringstream msg; + msg << "Failed to load the metallib " << lib_name << ".metallib. " + << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name) + << ">"; +#ifdef SWIFTPM_BUNDLE + msg << " and from the Swift PM bundle."; +#endif + throw std::runtime_error(msg.str()); } } // namespace @@ -210,7 +286,7 @@ void CommandEncoder::barrier() { Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); - library_map_ = {{"mlx", load_library(device_)}}; + library_map_ = {{"mlx", load_default_library(device_)}}; arch_ = std::string(device_->architecture()->name()->utf8String()); auto arch = arch_.back(); switch (arch) { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 1fe7cf76f..bb0e93147 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -189,15 +189,7 @@ class Device { void register_library( const std::string& lib_name, - const std::string& lib_path); - - // Note, this should remain in the header so that it is not dynamically - // linked - void register_library(const std::string& lib_name) { - if (auto it = library_map_.find(lib_name); it == library_map_.end()) { - register_library(lib_name, get_colocated_mtllib_path(lib_name)); - } - } + const std::string& lib_path = ""); MTL::Library* get_library( const std::string& name, diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 58d1521c9..246d6bcc5 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -24,10 +24,6 @@ void Event::wait() { } } -void Event::signal() { - static_cast(event_.get())->setSignaledValue(value()); -} - void Event::wait(Stream stream) { if (stream.device == Device::cpu) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); @@ -42,7 +38,9 @@ void Event::wait(Stream stream) { void Event::signal(Stream stream) { if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [*this]() mutable { signal(); }); + scheduler::enqueue(stream, [*this]() mutable { + static_cast(event_.get())->setSignaledValue(value()); + }); } else { auto& d = metal::device(stream.device); d.end_encoding(stream.index); diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 95678279e..153c62c02 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -356,20 +356,14 @@ void multi_upload_bluestein_fft( bool inverse, bool real, FFTPlan& plan, - std::vector copies, + std::vector& copies, const Stream& s) { // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's // algorithm int n = inverse ? out.shape(axis) : in.shape(axis); auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); - - // Broadcast w_q and w_k to the batch size - Strides b_strides(in.ndim(), 0); - b_strides[axis] = 1; - array w_k_broadcast({}, complex64, nullptr, {}); - array w_q_broadcast({}, complex64, nullptr, {}); - w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size()); - w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size()); + copies.push_back(w_k); + copies.push_back(w_q); auto temp_shape = inverse ? 
out.shape() : in.shape(); array temp(temp_shape, complex64, nullptr, {}); @@ -378,13 +372,13 @@ void multi_upload_bluestein_fft( if (real && !inverse) { // Convert float32->complex64 copy_gpu(in, temp, CopyType::General, s); + copies.push_back(temp); } else if (real && inverse) { int back_offset = n % 2 == 0 ? 2 : 1; auto slice_shape = in.shape(); slice_shape[axis] -= back_offset; array slice_temp(slice_shape, complex64, nullptr, {}); array conj_temp(in.shape(), complex64, nullptr, {}); - copies.push_back(slice_temp); copies.push_back(conj_temp); Shape rstarts(in.ndim(), 0); @@ -394,19 +388,28 @@ void multi_upload_bluestein_fft( unary_op_gpu({in}, conj_temp, "Conjugate", s); slice_gpu(in, slice_temp, rstarts, rstrides, s); concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s); + copies.push_back(temp); } else if (inverse) { unary_op_gpu({in}, temp, "Conjugate", s); + copies.push_back(temp); } else { temp.copy_shared_buffer(in); } + Strides b_strides(in.ndim(), 0); + b_strides[axis] = 1; + array w_k_broadcast(temp.shape(), complex64, nullptr, {}); + w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size()); binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s); std::vector> pads; auto padded_shape = out.shape(); padded_shape[axis] = plan.bluestein_n; array pad_temp(padded_shape, complex64, nullptr, {}); - pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s); + auto zero = array(complex64_t{0.0f, 0.0f}); + copies.push_back(zero); + pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s); + copies.push_back(pad_temp); array pad_temp1(padded_shape, complex64, nullptr, {}); fft_op( @@ -418,7 +421,10 @@ void multi_upload_bluestein_fft( FourStepParams(), /*inplace=*/false, s); + copies.push_back(pad_temp1); + array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {}); + w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size()); binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s); fft_op( @@ -435,9 +441,11 @@ void multi_upload_bluestein_fft( Shape starts(in.ndim(), 0); Shape strides(in.ndim(), 1); starts[axis] = plan.bluestein_n - offset - n; - slice_gpu(pad_temp1, temp, starts, strides, s); - binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s); + array temp2(temp_shape, complex64, nullptr, {}); + slice_gpu(pad_temp1, temp2, starts, strides, s); + + binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s); if (real && !inverse) { Shape rstarts(in.ndim(), 0); @@ -449,26 +457,21 @@ void multi_upload_bluestein_fft( array temp_float(out.shape(), out.dtype(), nullptr, {}); copies.push_back(temp_float); copies.push_back(inv_n); + copies.push_back(temp1); copy_gpu(temp1, temp_float, CopyType::General, s); binary_op_gpu({temp_float, inv_n}, out, "Multiply", s); } else if (inverse) { auto inv_n = array({1.0f / n}, {1}, complex64); - unary_op_gpu({temp1}, temp, "Conjugate", s); - binary_op_gpu({temp, inv_n}, out, "Multiply", s); + array temp3(temp_shape, complex64, nullptr, {}); + unary_op_gpu({temp1}, temp3, "Conjugate", s); + binary_op_gpu({temp3, inv_n}, out, "Multiply", s); copies.push_back(inv_n); + copies.push_back(temp1); + copies.push_back(temp3); } else { out.copy_shared_buffer(temp1); } - - copies.push_back(w_k); - copies.push_back(w_q); - copies.push_back(w_k_broadcast); - copies.push_back(w_q_broadcast); - copies.push_back(temp); - copies.push_back(temp1); - copies.push_back(pad_temp); - copies.push_back(pad_temp1); } void four_step_fft( @@ -478,8 +481,9 @@ void four_step_fft( bool 
inverse, bool real, FFTPlan& plan, - std::vector copies, - const Stream& s) { + std::vector& copies, + const Stream& s, + bool in_place) { auto& d = metal::device(s.device); if (plan.bluestein_n == -1) { @@ -492,7 +496,14 @@ void four_step_fft( in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s); four_step_params.first_step = false; fft_op( - temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s); + temp, + out, + axis, + inverse, + real, + four_step_params, + /*inplace=*/in_place, + s); copies.push_back(temp); } else { multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s); @@ -574,7 +585,7 @@ void fft_op( auto plan = plan_fft(n); if (plan.four_step) { - four_step_fft(in, out, axis, inverse, real, plan, copies, s); + four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace); d.add_temporaries(std::move(copies), s.index); return; } diff --git a/mlx/backend/metal/jit/arange.h b/mlx/backend/metal/jit/arange.h deleted file mode 100644 index 0c224dca4..000000000 --- a/mlx/backend/metal/jit/arange.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view arange_kernels = R"( -template [[host_name("{0}")]] [[kernel]] void arange<{1}>( - constant const {1}& start, - constant const {1}& step, - device {1}* out, - uint index [[thread_position_in_grid]]); -)"; diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index b14aa567b..27ae22d05 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -20,6 +20,7 @@ const char* copy(); const char* fft(); const char* gather_axis(); const char* hadamard(); +const char* logsumexp(); const char* quantized(); const char* ternary(); const char* scan(); @@ -32,6 +33,7 @@ const char* gemm(); const char* steel_gemm_fused(); const char* steel_gemm_masked(); const char* steel_gemm_splitk(); +const char* steel_gemm_gather(); const char* conv(); const char* steel_conv(); const char* steel_conv_general(); diff --git a/mlx/backend/metal/jit/softmax.h b/mlx/backend/metal/jit/softmax.h deleted file mode 100644 index a9672a050..000000000 --- a/mlx/backend/metal/jit/softmax.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view softmax_kernels = R"( -template [[host_name("block_{0}")]] [[kernel]] void -softmax_single_row<{1}, {2}>( - const device {1}* in, - device {1}* out, - constant int& axis_size, - uint gid [[thread_position_in_grid]], - uint _lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]); -template [[host_name("looped_{0}")]] [[kernel]] void -softmax_looped<{1}, {2}>( - const device {1}* in, - device {1}* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]); -)"; diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index a9cc267e1..5206c9b54 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -1,8 +1,6 @@ // Copyright © 2024 Apple Inc. 
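The jit_kernels.cpp hunks that follow drop the hand-maintained fmt::format template strings (previously kept in jit/arange.h and jit/softmax.h) and instead build the JIT source by concatenating the shared kernel headers with instantiations produced by get_template_definition. A minimal host-side sketch of that pattern is below; the helper name make_template_definition and the exact text it emits are illustrative assumptions, not the library's actual output format.

#include <iostream>
#include <sstream>
#include <string>

// Hypothetical stand-in for a template-definition generator: emit an explicit
// Metal template instantiation carrying a [[host_name(...)]] attribute. The
// exact string MLX's get_template_definition produces may differ; this only
// illustrates the shape of the generated code.
std::string make_template_definition(
    const std::string& host_name,
    const std::string& func,
    const std::string& type_arg) {
  std::ostringstream s;
  s << "\ntemplate [[host_name(\"" << host_name << "\")]] [[kernel]] decltype("
    << func << "<" << type_arg << ">) " << func << "<" << type_arg << ">;\n";
  return s.str();
}

int main() {
  // Build a JIT source string the way get_arange_kernel does below: shared
  // utilities, the kernel header, then one instantiation per requested dtype.
  std::string kernel_source = "// metal::utils() source would be appended here\n";
  kernel_source += "// metal::arange() source would be appended here\n";
  kernel_source += make_template_definition("arange_float32", "arange", "float");
  std::cout << kernel_source;
  return 0;
}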
#include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/includes.h" -#include "mlx/backend/metal/jit/softmax.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" @@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel( const std::string& kernel_name, const array& out) { auto lib = d.get_library(kernel_name, [&]() { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::arange() - << fmt::format( - arange_kernels, - kernel_name, - get_type_string(out.dtype())); - return kernel_source.str(); + std::string kernel_source = metal::utils(); + kernel_source += metal::arange(); + kernel_source += get_template_definition( + kernel_name, "arange", get_type_string(out.dtype())); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel( const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&] { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::softmax() - << fmt::format( - softmax_kernels, - lib_name, - get_type_string(out.dtype()), - get_type_string(precise ? float32 : out.dtype())); - return kernel_source.str(); + std::string kernel_source = metal::utils(); + auto in_type = get_type_string(out.dtype()); + auto acc_type = get_type_string(precise ? float32 : out.dtype()); + kernel_source += metal::softmax(); + kernel_source += get_template_definition( + "block_" + lib_name, "softmax_single_row", in_type, acc_type); + kernel_source += get_template_definition( + "looped_" + lib_name, "softmax_looped", in_type, acc_type); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_logsumexp_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name, [&] { + auto t_str = get_type_string(out.dtype()); + std::string kernel_source; + kernel_source = metal::utils(); + kernel_source += metal::logsumexp(); + kernel_source += + get_template_definition("block_" + lib_name, "logsumexp", t_str); + kernel_source += get_template_definition( + "looped_" + lib_name, "logsumexp_looped", t_str); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -568,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::steel_gemm_gather(), + get_template_definition( + lib_name, + rhs ? 
"gather_mm_rhs" : "gather_mm", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, @@ -698,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::quantized(), + get_template_definition( + lib_name, + "gather_qmm_rhs", + get_type_string(x.dtype()), + group_size, + bits, + bm, + bn, + bk, + wm, + wn, + transpose)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 63d17f959..6d8864385 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel( bool precise, const array& out); +MTL::ComputePipelineState* get_logsumexp_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, @@ -155,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( bool mn_aligned, bool k_aligned); +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs); + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, @@ -204,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel( const std::string& kernel_name, const std::string& template_def); +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose); + // Create a GPU kernel template definition for JIT compilation template std::string diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index f7dae3121..3ee88ca46 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS) if(MLX_METAL_DEBUG) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) endif() + if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") + set(METAL_FLAGS ${METAL_FLAGS} + "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") + endif() if(MLX_METAL_VERSION GREATER_EQUAL 310) set(VERSION_INCLUDES ${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1) @@ -65,6 +69,7 @@ set(STEEL_HEADERS steel/gemm/loader.h steel/gemm/transforms.h steel/gemm/kernels/steel_gemm_fused.h + 
steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_splitk.h steel/utils/type_traits.h @@ -105,12 +110,14 @@ if(NOT MLX_METAL_JIT) build_kernel(quantized quantized.h ${STEEL_HEADERS}) build_kernel(scan scan.h) build_kernel(softmax softmax.h) + build_kernel(logsumexp logsumexp.h) build_kernel(sort sort.h) build_kernel(ternary ternary.h ternary_ops.h) build_kernel(unary unary.h unary_ops.h) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) diff --git a/mlx/backend/metal/kernels/arange.metal b/mlx/backend/metal/kernels/arange.metal index c2e325697..fb56c1c5f 100644 --- a/mlx/backend/metal/kernels/arange.metal +++ b/mlx/backend/metal/kernels/arange.metal @@ -5,11 +5,7 @@ #include "mlx/backend/metal/kernels/arange.h" #define instantiate_arange(tname, type) \ - template [[host_name("arange" #tname)]] [[kernel]] void arange( \ - constant const type& start, \ - constant const type& step, \ - device type* out, \ - uint index [[thread_position_in_grid]]); + instantiate_kernel("arange" #tname, arange, type) instantiate_arange(uint8, uint8_t) instantiate_arange(uint16, uint16_t) diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 13ee239dc..620352144 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float); instantiate_naive_conv_2d_blocks(float16, half); instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); +/////////////////////////////////////////////////////////////////////////////// +/// Depthwise convolution kernels +/////////////////////////////////////////////////////////////////////////////// + +constant int ker_h [[function_constant(00)]]; +constant int ker_w [[function_constant(01)]]; +constant int str_h [[function_constant(10)]]; +constant int str_w [[function_constant(11)]]; +constant int tgp_h [[function_constant(100)]]; +constant int tgp_w [[function_constant(101)]]; +constant bool do_flip [[function_constant(200)]]; + +constant int span_h = tgp_h * str_h + ker_h - 1; +constant int span_w = tgp_w * str_w + ker_w - 1; +constant int span_hw = span_h * span_w; + +template +[[kernel]] void depthwise_conv_2d( + const device T* in [[buffer(0)]], + const device T* wt [[buffer(1)]], + device T* out [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int tc = 8; + constexpr int tw = 8; + constexpr int th = 4; + + constexpr int c_per_thr = 8; + + constexpr int TGH = th * 2 + 6; + constexpr int TGW = tw * 2 + 6; + constexpr int TGC = tc; + + threadgroup T ins[TGH * TGW * TGC]; + + const int n_tgblocks_h = params.oS[0] / th; + const int n = tid.z / n_tgblocks_h; + const int tghid = tid.z % n_tgblocks_h; + const int oh = tghid * th + lid.z; + const int ow = gid.y; + const int c = gid.x; + + in += n * params.in_strides[0]; + + // Load in + { + 
constexpr int n_threads = th * tw * tc; + const int tg_oh = (tghid * th) * str_h - params.pad[0]; + const int tg_ow = (tid.y * tw) * str_w - params.pad[1]; + const int tg_c = tid.x * tc; + + const int thread_idx = simd_gid * 32 + simd_lid; + constexpr int thr_per_hw = tc / c_per_thr; + constexpr int hw_per_group = n_threads / thr_per_hw; + + const int thr_c = thread_idx % thr_per_hw; + const int thr_hw = thread_idx / thr_per_hw; + + for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) { + const int h = hw / span_w; + const int w = hw % span_w; + + const int ih = tg_oh + h; + const int iw = tg_ow + w; + + const int in_s_offset = h * span_w * TGC + w * TGC; + + if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) { + const auto in_load = + in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c; + + MLX_MTL_PRAGMA_UNROLL + for (int cc = 0; cc < c_per_thr; ++cc) { + ins[in_s_offset + c_per_thr * thr_c + cc] = + in_load[c_per_thr * thr_c + cc]; + } + } else { + MLX_MTL_PRAGMA_UNROLL + for (int cc = 0; cc < c_per_thr; ++cc) { + ins[in_s_offset + c_per_thr * thr_c + cc] = T(0); + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + wt += c * params.wt_strides[0]; + + const auto ins_ptr = + &ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x]; + float o = 0.; + for (int h = 0; h < ker_h; ++h) { + for (int w = 0; w < ker_w; ++w) { + int wt_h = h; + int wt_w = w; + if (do_flip) { + wt_h = ker_h - h - 1; + wt_w = ker_w - w - 1; + } + auto inv = ins_ptr[h * span_w * TGC + w * TGC]; + auto wtv = wt[wt_h * ker_w + wt_w]; + o += inv * wtv; + } + } + threadgroup_barrier(mem_flags::mem_none); + + out += n * params.out_strides[0] + oh * params.out_strides[1] + + ow * params.out_strides[2]; + out[c] = static_cast(o); +} + +#define instantiate_depthconv2d(iname, itype) \ + instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype) + +instantiate_depthconv2d(float32, float); +instantiate_depthconv2d(float16, half); +instantiate_depthconv2d(bfloat16, bfloat16_t); + /////////////////////////////////////////////////////////////////////////////// /// Winograd kernels /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/kernels/fft.h b/mlx/backend/metal/kernels/fft.h index a4869a2ac..e478a85b6 100644 --- a/mlx/backend/metal/kernels/fft.h +++ b/mlx/backend/metal/kernels/fft.h @@ -483,4 +483,4 @@ template < perform_fft(fft_idx, &p, m, n, buf); read_writer.write_strided(stride, overall_n); -} \ No newline at end of file +} diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index f21c35d97..baaf84f2d 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -341,7 +341,7 @@ struct GEMVTKernel { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { - auto vc = float(v_coeff[tm]); + auto vc = static_cast(v_coeff[tm]); for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 4674a4228..51570e48d 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -493,71 +493,11 @@ template } // clang-format off -#define instantiate_layer_norm_single_row(name, itype) \ - template [[host_name("layer_norm" #name)]] [[kernel]] void \ - layer_norm_single_row( \ - const device itype* x, \ - const device itype* w, \ - const device itype* b, \ - 
device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - constant uint& b_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \ - vjp_layer_norm_single_row( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gw, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_layer_norm_looped(name, itype) \ - template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \ - layer_norm_looped( \ - const device itype* x, \ - const device itype* w, \ - const device itype* b, \ - device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - constant uint& b_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \ - vjp_layer_norm_looped( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gb, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_layer_norm(name, itype) \ - instantiate_layer_norm_single_row(name, itype) \ - instantiate_layer_norm_looped(name, itype) +#define instantiate_layer_norm(name, itype) \ + instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \ + instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \ + instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \ + instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype) instantiate_layer_norm(float32, float) instantiate_layer_norm(float16, half) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h new file mode 100644 index 000000000..b6898e31e --- /dev/null +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -0,0 +1,143 @@ +// Copyright © 2025 Apple Inc. 
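The new logsumexp kernels that follow keep a running maximum and rescale the partial sum of exponentials whenever the maximum grows, so nothing larger than exp(0) is ever computed. A minimal CPU-side sketch of the same update rule, assuming a plain float accumulator (the names and the -infinity initializer are illustrative; the kernel uses Limits<AccT>::finite_min and simdgroup reductions):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Streaming logsumexp: track (maxval, normalizer) and multiply the normalizer
// by exp(old_max - new_max) whenever a larger element appears, mirroring the
// per-thread accumulation loop in logsumexp_looped.
float streaming_logsumexp(const std::vector<float>& x) {
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    normalizer *= std::exp(prevmax - maxval); // no-op when the max is unchanged
    normalizer += std::exp(v - maxval);
  }
  // Matches the kernel's final isinf(maxval) guard.
  return std::isinf(maxval) ? maxval : std::log(normalizer) + maxval;
}

int main() {
  std::vector<float> x = {1000.0f, 1000.5f, 999.0f};
  // Stable despite inputs that would overflow a naive exp-then-log.
  std::printf("%f\n", streaming_logsumexp(x));
  return 0;
}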
+ +template +[[kernel]] void logsumexp( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = + ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(ld[i] - maxval); + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); + } + } +} + +template +[[kernel]] void logsumexp_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * size_t(axis_size); + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + // Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) + : Limits::finite_min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? 
vals[i] : maxval; + } + normalizer *= fast::exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(vals[i] - maxval); + } + } + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= fast::exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= fast::exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); + } + } +} diff --git a/mlx/backend/metal/kernels/logsumexp.metal b/mlx/backend/metal/kernels/logsumexp.metal new file mode 100644 index 000000000..eb76436cf --- /dev/null +++ b/mlx/backend/metal/kernels/logsumexp.metal @@ -0,0 +1,18 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +using namespace metal; + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/logsumexp.h" + +#define instantiate_logsumexp(name, itype) \ + instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \ + instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \ + +instantiate_logsumexp(float32, float) +instantiate_logsumexp(float16, half) +instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 3af3c971f..b2b0d8d8f 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -3,6 +3,10 @@ #include #include +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + using namespace metal; #define MLX_MTL_CONST static constant constexpr const @@ -586,13 +590,13 @@ METAL_FUNC void qmv_quad_impl( // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; w += out_row * in_vec_size_w + quad_lid * packs_per_thread; scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; - x += tid.y * in_vec_size + quad_lid * values_per_thread; - y += tid.y * out_vec_size + out_row; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; U sum = load_vector(x, x_thread); @@ -1686,26 +1690,26 @@ template < } template -[[kernel]] void bs_qmv_fast( +[[kernel]] void gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - 
const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1748,26 +1752,26 @@ template } template -[[kernel]] void bs_qmv( +[[kernel]] void gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1810,26 +1814,26 @@ template } template -[[kernel]] void bs_qvm( +[[kernel]] void gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant 
int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1879,27 +1883,27 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void bs_qmm_t( +[[kernel]] void gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - const constant int& batch_ndims [[buffer(16)]], - const constant int* batch_shape [[buffer(17)]], - const device uint32_t* lhs_indices [[buffer(18)]], - const device uint32_t* rhs_indices [[buffer(19)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1946,27 +1950,27 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void bs_qmm_n( +[[kernel]] void gather_qmm_n( const device uint32_t* w 
[[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - const constant int& batch_ndims [[buffer(16)]], - const constant int* batch_shape [[buffer(17)]], - const device uint32_t* lhs_indices [[buffer(18)]], - const device uint32_t* rhs_indices [[buffer(19)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], @@ -2007,6 +2011,289 @@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? 
short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void gather_qmm_rhs( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? 
y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + template [[kernel]] void affine_quantize( const device T* w [[buffer(0)]], diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 7af554437..11cd8421b 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -60,6 +60,20 @@ bits, \ split_k) +#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \ + 
instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + group_size, \ + bits, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \ instantiate_quantized_batched(name, type, group_size, bits, 0) @@ -73,14 +87,14 @@ #define instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \ - instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ - instantiate_quantized(bs_qmv, type, group_size, bits) \ - instantiate_quantized(bs_qvm, type, group_size, bits) \ - instantiate_quantized(bs_qmm_n, type, group_size, bits) + instantiate_quantized(gather_qmv_fast, type, group_size, bits) \ + instantiate_quantized(gather_qmv, type, group_size, bits) \ + instantiate_quantized(gather_qvm, type, group_size, bits) \ + instantiate_quantized(gather_qmm_n, type, group_size, bits) #define instantiate_quantized_all_aligned(type, group_size, bits) \ - instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \ - instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \ + instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \ + instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ @@ -96,12 +110,17 @@ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) +#define instantiate_quantized_all_rhs(type, group_size, bits) \ + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) + #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_all_quad(type, group_size, bits) \ - instantiate_quantized_all_splitk(type, group_size, bits) + instantiate_quantized_all_splitk(type, group_size, bits) \ + instantiate_quantized_all_rhs(type, group_size, bits) #define instantiate_quantized_types(group_size, bits) \ instantiate_quantized_funcs(float, group_size, bits) \ diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index f4c1536de..62f2457b7 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -380,69 +380,11 @@ template } // clang-format off -#define instantiate_rms_single_row(name, itype) \ - template [[host_name("rms" #name)]] [[kernel]] void \ - rms_single_row( \ - const device itype* x, \ - const device itype* w, \ - device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id 
[[simdgroup_index_in_threadgroup]]); \ - \ - template [[host_name("vjp_rms" #name)]] [[kernel]] void \ - vjp_rms_single_row( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gw, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_rms_looped(name, itype) \ - template [[host_name("rms_looped" #name)]] [[kernel]] void \ - rms_looped( \ - const device itype* x, \ - const device itype* w, \ - device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - \ - template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \ - vjp_rms_looped( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gw, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_rms(name, itype) \ - instantiate_rms_single_row(name, itype) \ - instantiate_rms_looped(name, itype) +#define instantiate_rms(name, itype) \ + instantiate_kernel("rms" #name, rms_single_row, itype) \ + instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \ + instantiate_kernel("rms_looped" #name, rms_looped, itype) \ + instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype) instantiate_rms(float32, float) instantiate_rms(float16, half) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index ea80396df..c668d9d8c 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -1,11 +1,11 @@ #include -#include "mlx/backend/metal/kernels/sdpa_vector.h" +// clang-format off #include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/sdpa_vector.h" using namespace metal; -// clang-format off // SDPA vector instantiations #define instantiate_sdpa_vector_aggregation(type, value_dim) \ instantiate_kernel( \ @@ -32,9 +32,11 @@ using namespace metal; instantiate_sdpa_vector(type, 64, 64) \ instantiate_sdpa_vector(type, 96, 96) \ instantiate_sdpa_vector(type, 128, 128) \ + instantiate_sdpa_vector(type, 256, 256) \ instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 96) \ - instantiate_sdpa_vector_aggregation(type, 128) + instantiate_sdpa_vector_aggregation(type, 128) \ + instantiate_sdpa_vector_aggregation(type, 256) instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) diff --git a/mlx/backend/metal/kernels/scan.h b/mlx/backend/metal/kernels/scan.h index cfa84c04c..cb5147558 100644 --- a/mlx/backend/metal/kernels/scan.h +++ b/mlx/backend/metal/kernels/scan.h @@ -2,6 +2,8 @@ #pragma once +#include 
"mlx/backend/metal/kernels/binary_ops.h" + #define DEFINE_SIMD_SCAN() \ template = true> \ T simd_scan(T val) { \ @@ -139,6 +141,29 @@ struct CumMin { } }; +template +struct CumLogaddexp { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return LogAddExp{}(a, static_cast(b)); + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = LogAddExp{}(x, other); + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + template inline void load_unsafe(U values[N_READS], const device T* input) { if (reverse) { diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 6aa36f5a3..8fcd7f61b 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4) instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4) instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4) -instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on +instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) +instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 88a109b88..c4c0f6456 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -6,6 +6,9 @@ using namespace metal; constant bool has_mask [[function_constant(20)]]; constant bool query_transposed [[function_constant(21)]]; +constant bool do_causal [[function_constant(22)]]; +constant bool bool_mask [[function_constant(23)]]; +constant bool float_mask [[function_constant(24)]]; template [[kernel]] void sdpa_vector( @@ -13,17 +16,21 @@ template const device T* keys [[buffer(1)]], const device T* values [[buffer(2)]], device T* out [[buffer(3)]], - const constant int& gqa_factor, - const constant int& N, - const constant size_t& k_head_stride, - const constant size_t& k_seq_stride, - const constant size_t& v_head_stride, - const constant size_t& v_seq_stride, - const constant float& scale, - const device bool* mask [[function_constant(has_mask)]], - const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], - const constant int& mask_q_seq_stride [[function_constant(has_mask)]], - const constant int& mask_head_stride [[function_constant(has_mask)]], + const constant int& gqa_factor [[buffer(4)]], + const constant int& N [[buffer(5)]], + const constant size_t& k_head_stride [[buffer(6)]], + const constant size_t& k_seq_stride [[buffer(7)]], + const constant size_t& v_head_stride [[buffer(8)]], + const constant size_t& v_seq_stride [[buffer(9)]], + const constant float& scale [[buffer(10)]], + const device bool* bmask [[buffer(11), function_constant(bool_mask)]], + const device T* fmask [[buffer(12), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(13), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(14), 
function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(15), function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -57,8 +64,12 @@ template simd_lid * qk_per_thread; values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + simd_lid * v_per_thread; - if (has_mask) { - mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + + if (bool_mask) { + bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } @@ -77,7 +88,13 @@ template // For each key for (int i = simd_gid; i < N; i += BN) { - if (!has_mask || mask[0]) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } + if (use_key) { // Read the key for (int j = 0; j < qk_per_thread; j++) { k[j] = keys[j]; @@ -89,6 +106,9 @@ template score += q[j] * k[j]; } score = simd_sum(score); + if (float_mask) { + score += max(Limits::finite_min, static_cast(fmask[0])); + } // Update the accumulators U new_max = max(max_score, score); @@ -107,8 +127,11 @@ template // Move the pointers to the next kv keys += inner_k_stride; values += inner_v_stride; - if (has_mask) { - mask += BN * mask_kv_seq_stride; + if (bool_mask) { + bmask += BN * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * mask_kv_seq_stride; } } @@ -149,17 +172,21 @@ template device float* out [[buffer(3)]], device float* sums [[buffer(4)]], device float* maxs [[buffer(5)]], - const constant int& gqa_factor, - const constant int& N, - const constant size_t& k_head_stride, - const constant size_t& k_seq_stride, - const constant size_t& v_head_stride, - const constant size_t& v_seq_stride, - const constant float& scale, - const device bool* mask [[function_constant(has_mask)]], - const constant int& mask_kv_seq_stride [[function_constant(has_mask)]], - const constant int& mask_q_seq_stride [[function_constant(has_mask)]], - const constant int& mask_head_stride [[function_constant(has_mask)]], + const constant int& gqa_factor [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant size_t& k_head_stride [[buffer(8)]], + const constant size_t& k_seq_stride [[buffer(9)]], + const constant size_t& v_head_stride [[buffer(10)]], + const constant size_t& v_seq_stride [[buffer(11)]], + const constant float& scale [[buffer(12)]], + const device bool* bmask [[buffer(13), function_constant(bool_mask)]], + const device T* fmask [[buffer(14), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(15), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(16), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(17), function_constant(has_mask)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -197,8 +224,13 @@ template values += kv_head_idx * v_head_stride + (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; - if (has_mask) { - mask += head_idx * mask_head_stride + + if (bool_mask) { + bmask += head_idx * mask_head_stride + + (block_idx * BN + simd_gid) * mask_kv_seq_stride + + q_seq_idx * 
mask_q_seq_stride; + } + if (float_mask) { + fmask += head_idx * mask_head_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } @@ -218,7 +250,13 @@ template // For each key for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { - if (!has_mask || mask[0]) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } + if (use_key) { // Read the key for (int i = 0; i < qk_per_thread; i++) { k[i] = keys[i]; @@ -230,6 +268,9 @@ template score += q[i] * k[i]; } score = simd_sum(score); + if (float_mask) { + score += fmask[0]; + } // Update the accumulators U new_max = max(max_score, score); @@ -248,8 +289,11 @@ template // Move the pointers to the next kv keys += blocks * inner_k_stride; values += blocks * inner_v_stride; - if (has_mask) { - mask += BN * blocks * mask_kv_seq_stride; + if (bool_mask) { + bmask += BN * blocks * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * blocks * mask_kv_seq_stride; } } diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal index 1b64d59a1..79d5d3fca 100644 --- a/mlx/backend/metal/kernels/softmax.metal +++ b/mlx/backend/metal/kernels/softmax.metal @@ -9,47 +9,13 @@ using namespace metal; #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/softmax.h" -#define instantiate_softmax(name, itype) \ - template [[host_name("block_softmax_" #name)]] [[kernel]] void \ - softmax_single_row( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[thread_position_in_grid]], \ - uint _lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("looped_softmax_" #name)]] [[kernel]] void \ - softmax_looped( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[threadgroup_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_softmax(name, itype) \ + instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \ + instantiate_kernel("looped_softmax_" #name, softmax_looped, itype) -#define instantiate_softmax_precise(name, itype) \ - template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \ - softmax_single_row( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[thread_position_in_grid]], \ - uint _lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \ - softmax_looped( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[threadgroup_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_softmax_precise(name, itype) \ + instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \ + instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float) 
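+// The precise variants pass float as the accumulation type, so reductions
+// run in float regardless of the input element type.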
instantiate_softmax(float32, float) instantiate_softmax(float16, half) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index a8469e0ff..2e27ea06f 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -229,7 +229,7 @@ template < // Init to -Inf STEEL_PRAGMA_UNROLL for (short i = 0; i < kRowsPT; ++i) { - max_score[i] = Limits::min; + max_score[i] = Limits::finite_min; } int kb_lim = params->NK; @@ -237,6 +237,7 @@ template < if (do_causal) { int q_max = (tid.x + 1) * BQ + params->qL_off; kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); } // Loop over KV seq length @@ -272,7 +273,7 @@ template < if (!align_K && kb == (params->NK_aligned)) { using stile_t = decltype(Stile); using selem_t = typename stile_t::elem_type; - constexpr auto neg_inf = -metal::numeric_limits::infinity(); + constexpr auto neg_inf = Limits::finite_min; STEEL_PRAGMA_UNROLL for (short i = 0; i < stile_t::kTileRows; i++) { @@ -290,10 +291,10 @@ template < } // Mask out if causal - if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { using stile_t = decltype(Stile); using selem_t = typename stile_t::elem_type; - constexpr auto neg_inf = -metal::numeric_limits::infinity(); + constexpr auto neg_inf = Limits::finite_min; STEEL_PRAGMA_UNROLL for (short i = 0; i < stile_t::kTileRows; i++) { @@ -316,7 +317,7 @@ template < if (has_mask) { using stile_t = decltype(Stile); using selem_t = typename stile_t::elem_type; - constexpr auto neg_inf = -metal::numeric_limits::infinity(); + constexpr auto neg_inf = Limits::finite_min; constexpr bool is_bool = is_same_v; using melem_t = typename metal::conditional_t; diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index bcc585bbe..add495d93 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; -constant bool do_gather [[function_constant(300)]]; - -constant bool gather_bias = do_gather && use_out_source; - // clang-format off template < typename T, @@ -39,12 +35,6 @@ template < const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant int* batch_shape [[buffer(6)]], const constant int64_t* batch_strides [[buffer(7)]], - const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], - const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], - const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], - const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], - const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]], - const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -81,84 +71,26 @@ template < } // Adjust for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + 
const constant auto* B_bstrides = batch_strides + params->batch_ndim; - // Handle gather - if (do_gather) { - // Read indices - uint32_t indx_A, indx_B, indx_C; + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - if (has_batch) { - const constant auto* indx_A_bstrides = batch_strides; - const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim; - - ulong2 indx_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - indx_A_bstrides, - indx_B_bstrides, - params->batch_ndim); - indx_A = lhs_indices[indx_offsets.x]; - indx_B = rhs_indices[indx_offsets.y]; - - if (use_out_source) { - const constant auto* indx_C_bstrides = - indx_B_bstrides + params->batch_ndim; - auto indx_offset_C = elem_to_loc( - tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); - indx_C = C_indices[indx_offset_C]; - } - } else { - indx_A = lhs_indices[params->batch_stride_a * tid.z]; - indx_B = rhs_indices[params->batch_stride_b * tid.z]; - - if (use_out_source) { - indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; - } - } - - // Translate indices to offsets - int batch_ndim_A = operand_batch_ndim.x; - const constant int* batch_shape_A = operand_shape; - const constant auto* batch_strides_A = operand_strides; - A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); - - int batch_ndim_B = operand_batch_ndim.y; - const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; - const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A; - B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + A += batch_offsets.x; + B += batch_offsets.y; if (use_out_source) { - int batch_ndim_C = operand_batch_ndim.z; - const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; - const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B; - C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; - } - - // Handle regular batch - else { - if (has_batch) { - const constant auto* A_bstrides = batch_strides; - const constant auto* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - if (use_out_source) { - const constant auto* C_bstrides = B_bstrides + params->batch_ndim; - C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); - } - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - - if (use_out_source) { - C += addmm_params->batch_stride_c * tid.z; - } + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; } } diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h new file mode 100644 index 000000000..4493375c1 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h @@ -0,0 +1,459 @@ +// Copyright © 2024 Apple Inc. 
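+
+// Steel GEMM kernels whose operands are selected through index arrays.
+// gather_mm_rhs gathers only B: it reads one rhs index per row of A and
+// batches contiguous rows that share an index over the same B matrix.
+// gather_mm gathers both operands, resolving one lhs and one rhs index per
+// batch element before running a regular tiled GEMM.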
+ +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[c_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (rhs_indices[c_row + n] != index) { + offset_next = n; + index_next = rhs_indices[c_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b( + B + index * params->batch_stride_b, + params->ldb, + Bs, + simd_group_id, + simd_lane_id); + + // Prepare iterations + const int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* lhs_indices [[buffer(2)]], + const device uint32_t* rhs_indices [[buffer(3)]], + device T* C [[buffer(4)]], + const constant GEMMParams* params [[buffer(5)]], + const constant int* indices_shape [[buffer(6)]], + const constant int64_t* lhs_strides [[buffer(7)]], + const constant int64_t* rhs_strides [[buffer(8)]], + const constant int& batch_ndim_a [[buffer(9)]], + const constant int* batch_shape_a [[buffer(10)]], + const constant int64_t* batch_strides_a [[buffer(11)]], + const constant int& batch_ndim_b [[buffer(12)]], + const constant int* batch_shape_b [[buffer(13)]], + const constant int64_t* batch_strides_b [[buffer(14)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, 
+ BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Move A and B to the locations pointed by lhs_indices and rhs_indices. + uint32_t indx_A, indx_B; + if (has_batch) { + ulong2 indices_offsets = elem_to_loc_broadcast( + tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim); + indx_A = lhs_indices[indices_offsets.x]; + indx_B = rhs_indices[indices_offsets.y]; + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + } + A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a); + B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b); + C += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Just make sure everybody's finished with the indexing math above. + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + mma_op.store_result(C, params->ldd); + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal new file mode 100644 index 000000000..f8e5a2a37 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal @@ -0,0 +1,59 @@ +// Copyright © 2024 Apple Inc. 
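+
+// Instantiates gather_mm and gather_mm_rhs for half, bfloat16_t and float
+// over the tile shapes listed in instantiate_gather_mm_shapes_helper below.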
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h" + +#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm_rhs, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) +// clang-format on + +instantiate_gather_mm_shapes_helper(float16, half, float16, half); +instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); +instantiate_gather_mm_shapes_helper(float32, float, float32, float); diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index aea235abb..64b87655e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -142,6 +142,42 @@ struct BaseMMAFrag { } } + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_slice( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < stop_x && (off_x + i) >= start_x && + (off_y + j) < stop_y && (off_y + j) >= start_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * 
kElemCols + j]); + } + } + } + } + METAL_FUNC static constexpr void mma( thread frag_type& D, thread frag_type& A, @@ -335,6 +371,31 @@ struct MMATile { } } } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_slice( + frag_at(i, j), + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } }; template @@ -474,6 +535,26 @@ struct BlockMMA { Ctile.template store(D, ldd); } + METAL_FUNC void + store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + D += sm * ldd + sn; + start -= short2(sn, sm); + stop -= short2(sn, sm); + + // TODO: Check the start as well + if (stop.y <= 0 || stop.x <= 0) { + return; + } + + Ctile.template store_slice(D, ldd, start, stop); + } + METAL_FUNC void store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { // Apply epilogue diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 82692c8e5..2209b0665 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t) +instantiate_unary_all_same(Log, complex64, complex64_t) +instantiate_unary_all_same(Log2, complex64, complex64_t) +instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index ceed3efe5..52e126b40 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -257,6 +257,13 @@ struct Log { T operator()(T x) { return metal::precise::log(x); }; + + template <> + complex64_t operator()(complex64_t x) { + auto r = metal::precise::log(Abs{}(x).real); + auto i = metal::precise::atan2(x.imag, x.real); + return {r, i}; + }; }; struct Log2 { @@ -264,6 +271,12 @@ struct Log2 { T operator()(T x) { return metal::precise::log2(x); }; + + template <> + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN2_F, y.imag / M_LN2_F}; + }; }; struct Log10 { @@ -271,6 +284,12 @@ struct Log10 { T operator()(T x) { return metal::precise::log10(x); }; + + template <> + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN10_F, y.imag / M_LN10_F}; + }; }; struct Log1p { diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp new file mode 100644 index 000000000..4901190e1 --- /dev/null +++ b/mlx/backend/metal/logsumexp.cpp @@ -0,0 +1,96 @@ +// Copyright © 2023-2024 Apple Inc. 
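+
+// Metal implementation of LogSumExp over the last axis. Rows of up to
+// LOGSUMEXP_LOOPED_LIMIT elements use the block_logsumexp kernel (one
+// threadgroup per row); longer rows fall back to the looped_logsumexp
+// variant, which loops a full-size threadgroup over the row.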
+#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096; + +void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + if (!issubdtype(out.dtype(), floating)) { + throw std::runtime_error( + "[logsumexp] Does not support non-floating point types."); + } + auto& s = stream(); + auto& d = metal::device(s.device); + + // Make sure that the last dimension is contiguous + auto ensure_contiguous = [&s, &d](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } + }; + + auto in = ensure_contiguous(inputs[0]); + if (in.flags().row_contiguous) { + out.set_data(allocator::malloc(out.nbytes())); + } else { + auto n = in.shape(-1); + auto flags = in.flags(); + auto strides = in.strides(); + for (auto& s : strides) { + s /= n; + } + bool col_contig = strides[0] == 1; + for (int i = 1; col_contig && i < strides.size(); ++i) { + col_contig &= + (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); + } + flags.col_contiguous = col_contig; + out.set_data( + allocator::malloc(in.nbytes() / n), + in.data_size() / n, + std::move(strides), + flags); + } + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + const int simd_size = 32; + const int n_reads = 4; + const int looped_limit = LOGSUMEXP_LOOPED_LIMIT; + + std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_"; + kernel_name += "logsumexp_"; + kernel_name += type_to_name(out); + + auto kernel = get_logsumexp_kernel(d, kernel_name, out); + auto& compute_encoder = d.get_command_encoder(s.index); + { + MTL::Size grid_dims, group_dims; + if (axis_size <= looped_limit) { + size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; + size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; + size_t threadgroup_size = simd_size * simds_needed; + assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + size_t n_threads = n_rows * threadgroup_size; + grid_dims = MTL::Size(n_threads, 1, 1); + group_dims = MTL::Size(threadgroup_size, 1, 1); + } else { + size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); + size_t n_threads = n_rows * threadgroup_size; + grid_dims = MTL::Size(n_threads, 1, 1); + group_dims = MTL::Size(threadgroup_size, 1, 1); + } + + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(axis_size, 2); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 3f736505f..f55d20c9f 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -5,6 +5,7 @@ #include #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -102,6 +103,47 @@ std::tuple check_transpose( } }; +inline array +ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, 
{}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } else { + return x; + } +} + +inline std::tuple +ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (x.flags().row_contiguous) { + return std::make_tuple(false, x.strides()[x.ndim() - 2], x); + } + + bool rc = true; + for (int i = 0; i < x.ndim() - 3; i++) { + rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i]; + } + if (rc) { + auto stx = x.strides()[x.ndim() - 2]; + auto sty = x.strides()[x.ndim() - 1]; + auto K = x.shape(-2); + auto N = x.shape(-1); + if (sty == 1 && (N != 1 || stx == N)) { + return std::make_tuple(false, stx, x); + } + if (stx == 1 && (N != 1 || sty == K)) { + return std::make_tuple(true, sty, x); + } + } + + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); +} + } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -230,7 +272,6 @@ void steel_matmul_regular( const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = false; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, @@ -239,7 +280,6 @@ void steel_matmul_regular( {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; // clang-format off @@ -248,8 +288,7 @@ void steel_matmul_regular( << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = false; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, @@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; // clang-format off @@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 
't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } -void GatherMM::eval_gpu(const std::vector& inputs, array& out) { - using namespace mlx::steel; - // assert(inputs.size() == 2); - if (!issubdtype(out.dtype(), floating)) { - throw std::runtime_error( - "[GatherMM] Does not yet support non-floating point types."); - } - auto& s = stream(); - auto& d = metal::device(s.device); +void gather_mm_rhs( + const array& a_, + const array& b_, + const array& indices_, + array& out, + metal::Device& d, + const Stream& s) { + array indices = ensure_row_contiguous(indices_, d, s); + auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); - auto& a_pre = inputs[0]; - auto& b_pre = inputs[1]; - // Return 0s if either input is empty - if (a_pre.size() == 0 || b_pre.size() == 0) { - array zero = array(0, a_pre.dtype()); - fill_gpu(zero, out, s); - d.add_temporary(std::move(zero), s.index); - return; - } + // Broadcast a with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of a broadcasted + // with rhs_indices. We need only broadcast a and copy it as if applying the + // lhs_indices. + auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); + } - out.set_data(allocator::malloc(out.nbytes())); + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); + }; + array a = broadcast_with_indices(a_); - ///////////////////////////////////////////////////////////////////////////// - // Init checks and prep + // Extract the matmul shapes + int K = a.shape(-1); + int M = a.size() / K; + int N = b.shape(-1); + int lda = a.strides()[a.ndim() - 2]; // should be K - int M = a_pre.shape(-2); - int N = b_pre.shape(-1); - int K = a_pre.shape(-1); + // Define the dispatch blocks + int bm = 16, bn = 64, bk = 16; + int wm = 1, wn = 2; - // Keep a vector with copies to be cleared in the completed buffer to release - // the arrays - std::vector copies; - auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); - auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; - int lda = a_cols; - int ldb = b_cols; + // Define the kernel name + std::string base_name; + base_name.reserve(64); + concatenate( + base_name, + "steel_gather_mm_rhs_n", + transpose_b ? 
't' : 'n', + '_', + type_to_name(a), + '_', + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); - ///////////////////////////////////////////////////////////////////////////// - // Check and collapse batch dimensions - - auto get_batch_dims = [](const auto& v) { - return decltype(v){v.begin(), v.end() - 2}; + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, }; - auto& lhs_indices = inputs[2]; - auto& rhs_indices = inputs[3]; + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); - Shape batch_shape = get_batch_dims(out.shape()); - Strides batch_strides; + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_gather_kernel( + d, + base_name, + hash_name, + func_consts, + out, + false, + transpose_b, + bm, + bn, + bk, + wm, + wn, + true); + compute_encoder.set_compute_pipeline_state(kernel); - batch_strides.insert( - batch_strides.end(), - lhs_indices.strides().begin(), - lhs_indices.strides().end()); - auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + // Prepare the matmul params + auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size(); + steel::GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ static_cast(ldb), + /* const int ldd = */ N, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ 0, + /* const int64_t batch_stride_b = */ static_cast(batch_stride_b), + /* const int64_t batch_stride_d = */ 0, + /* const int swizzle_log = */ 0, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ 0}; - batch_strides.insert( - batch_strides.end(), - rhs_indices.strides().begin(), - rhs_indices.strides().end()); - auto rhs_indices_str = batch_strides.empty() ? 
0 : batch_strides.back(); + // Prepare the grid + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1); - int batch_ndim = batch_shape.size(); + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(indices, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(params, 4); - if (batch_ndim == 0) { - batch_shape = {1}; - batch_strides = {0}; - } + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} - int batch_ndim_A = a.ndim() - 2; - int batch_ndim_B = b.ndim() - 2; - std::vector operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; +void gather_mv( + const array& mat_, + const array& vec_, + const array& mat_indices_, + const array& vec_indices_, + array& out, + int N, + int K, + bool is_mv, + metal::Device& d, + const Stream& s) { + // Copy if needed + std::vector copies; + auto [transpose_mat, mat_cols, mat] = + check_transpose(copies, s, mat_, N == 1); + auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true); + d.add_temporaries(std::move(copies), s.index); - Shape batch_shape_A = get_batch_dims(a.shape()); - Strides batch_strides_A = get_batch_dims(a.strides()); - Shape batch_shape_B = get_batch_dims(b.shape()); - Strides batch_strides_B = get_batch_dims(b.strides()); + // If we are doing vector matrix instead of matrix vector we need to flip the + // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated + // as a one dimensional array. + transpose_mat = (!is_mv) ^ transpose_mat; - if (batch_ndim_A == 0) { - batch_shape_A = {1}; - batch_strides_A = {0}; - } + // Define some shapes + int in_vector_len = K; + int out_vector_len = N; + int mat_ld = mat_cols; - if (batch_ndim_B == 0) { - batch_shape_B = {1}; - batch_strides_B = {0}; - } + int batch_size_out = out.size() / N; + int batch_ndim = out.ndim() - 2; + int batch_ndim_mat = mat.ndim() - 2; + int batch_ndim_vec = vec.ndim() - 2; + Strides index_strides = vec_indices_.strides(); + index_strides.insert( + index_strides.end(), + mat_indices_.strides().begin(), + mat_indices_.strides().end()); - auto matrix_stride_out = static_cast(M) * N; - auto batch_size_out = out.size() / matrix_stride_out; - - ///////////////////////////////////////////////////////////////////////////// - // Gemv specialization - - // Route to gemv if needed - if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A; - auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B; - - auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A; - auto batch_shape_vec = is_b_matrix ? 
batch_shape_A : batch_shape_B; - - if (!is_b_matrix) { - batch_strides = rhs_indices.strides(); - batch_strides.insert( - batch_strides.end(), - lhs_indices.strides().begin(), - lhs_indices.strides().end()); - } - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_gather_" << type_to_name(out); + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + if (transpose_mat) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_gather_" << type_to_name(out); + sm = 8; + sn = 4; } - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; + if (out_vector_len >= 2048) { + bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; + } - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_gather_" << type_to_name(out); - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + sn = 32; - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides, 11); - - int batch_ndim_vec = batch_shape_vec.size(); - compute_encoder.set_bytes(batch_ndim_vec, 12); - compute_encoder.set_vector_bytes(batch_shape_vec, 13); - compute_encoder.set_vector_bytes(batch_strides_vec, 14); - - int batch_ndim_mat = batch_shape_mat.size(); - compute_encoder.set_bytes(batch_ndim_mat, 15); - compute_encoder.set_vector_bytes(batch_shape_mat, 16); - compute_encoder.set_vector_bytes(batch_strides_mat, 17); - - compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix)); - compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix)); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + n_out_per_tgp = bm * sm * tm; + kname << "gemv_gather_" << type_to_name(out); } - ///////////////////////////////////////////////////////////////////////////// - // Regular kernel dispatch + kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" + << tm << "_tn" << tn; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(out.shape(), 10); + compute_encoder.set_vector_bytes(index_strides, 11); + + compute_encoder.set_bytes(batch_ndim_vec, 12); + compute_encoder.set_vector_bytes(vec.shape(), 13); + compute_encoder.set_vector_bytes(vec.strides(), 14); + + compute_encoder.set_bytes(batch_ndim_mat, 15); + compute_encoder.set_vector_bytes(mat.shape(), 16); + compute_encoder.set_vector_bytes(mat.strides(), 17); + + compute_encoder.set_input_array(vec_indices_, 18); + compute_encoder.set_input_array(mat_indices_, 19); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_mm( + const array& a_, + const array& b_, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + // Copy if needed + std::vector copies; + auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); + auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); + d.add_temporaries(std::move(copies), s.index); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; + size_t batch_size_out = out.size() / M / N; + int batch_ndim = out.ndim() - 2; + int batch_ndim_a = a.ndim() - 2; + int batch_ndim_b = b.ndim() - 2; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - const bool has_batch = batch_ndim > 1; - const bool use_out_source = false; - const bool do_axpby = false; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = true; + + // Define the kernel name + std::string base_name; + base_name.reserve(128); + concatenate( + base_name, + "steel_gather_mm_", + transpose_a ? 't' : 'n', + transpose_b ? 't' : 'n', + "_", + type_to_name(a), + "_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_has_batch_", + has_batch ? 't' : 'n', + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); - std::string hash_name = kname.str(); - - // Encode and dispatch kernel + // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( + auto kernel = get_steel_gemm_gather_kernel( d, base_name, hash_name, @@ -1736,72 +1842,96 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { bn, bk, wm, - wn); - + wn, + false); compute_encoder.set_compute_pipeline_state(kernel); - // Use problem size to determine threadblock swizzle - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams params{ + // Prepare the matmul params + steel::GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, + /* const int lda = */ static_cast(lda), + /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ lhs_indices_str, - /* const int64_t batch_stride_b = */ rhs_indices_str, - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ + (batch_ndim > 0) ? lhs_indices.strides()[0] : 0, + /* const int64_t batch_stride_b = */ + (batch_ndim > 0) ? 
rhs_indices.strides()[0] : 0, + /* const int64_t batch_stride_d = */ M * N, + /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ batch_ndim}; - // Prepare launch grid params - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - + // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + MTL::Size grid_dims = + MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(params, 4); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.set_input_array(lhs_indices, 10); - compute_encoder.set_input_array(rhs_indices, 11); - - std::vector operand_shape = batch_shape_A; - operand_shape.insert( - operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end()); - - std::vector operand_strides = batch_strides_A; - operand_strides.insert( - operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end()); - - operand_batch_ndim.push_back(0); - - compute_encoder.set_vector_bytes(operand_shape, 13); - compute_encoder.set_vector_bytes(operand_strides, 14); - compute_encoder.set_vector_bytes(operand_batch_ndim, 15); - + compute_encoder.set_input_array(lhs_indices, 2); + compute_encoder.set_input_array(rhs_indices, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(params, 5); + compute_encoder.set_vector_bytes(lhs_indices.shape(), 6); + compute_encoder.set_vector_bytes(lhs_indices.strides(), 7); + compute_encoder.set_vector_bytes(rhs_indices.strides(), 8); + compute_encoder.set_bytes(batch_ndim_a, 9); + compute_encoder.set_vector_bytes(a.shape(), 10); + compute_encoder.set_vector_bytes(a.strides(), 11); + compute_encoder.set_bytes(batch_ndim_b, 12); + compute_encoder.set_vector_bytes(b.shape(), 13); + compute_encoder.set_vector_bytes(b.strides(), 14); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} - d.add_temporaries(std::move(copies), s.index); +void GatherMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& lhs_indices = inputs[2]; + auto& rhs_indices = inputs[3]; + + // Return 0s if either input is empty + if (a.size() == 0 || b.size() == 0) { + array zero = array(0, a.dtype()); + fill_gpu(zero, out, s); + d.add_temporary(std::move(zero), s.index); + return; + } + + out.set_data(allocator::malloc(out.nbytes())); + + // Extract shapes from inputs. + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + // We are walking a in order and b is also in order so we can batch up the + // matmuls and reuse reading a and b. 
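// Illustrative aside, not part of the patch: the dispatch that follows picks a
// specialized kernel purely from the operand shapes and the sortedness of the
// right-hand indices. A standalone sketch of the same decision order is below;
// the enum and helper function are invented here for illustration only.
#include <cstdio>

enum class GatherMMKernel { RhsSorted, GemvSwapped, Gemv, GeneralGemm };

GatherMMKernel route_gather_mm(int M, int N, bool right_sorted) {
  if (M == 1 && right_sorted) return GatherMMKernel::RhsSorted;  // gather_mm_rhs
  if (M == 1) return GatherMMKernel::GemvSwapped;                // gather_mv with a/b swapped
  if (N == 1) return GatherMMKernel::Gemv;                       // gather_mv
  return GatherMMKernel::GeneralGemm;                            // gather_mm
}

int main() {
  // e.g. a single query row against sorted expert indices hits the rhs path
  std::printf("%d\n", static_cast<int>(route_gather_mm(/*M=*/1, /*N=*/4096, /*right_sorted=*/true)));
}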
+ if (M == 1 && right_sorted_ == true) { + gather_mm_rhs(a, b, rhs_indices, out, d, s); + return; + } + + // Route to gather gemv if any of a or b are vectors + if (M == 1) { + gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s); + return; + } + if (N == 1) { + gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s); + return; + } + + // Route to non specialized gather mm + gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } } // namespace mlx::core diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 82151c538..d162007d1 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -12,74 +12,6 @@ namespace mlx::core::metal { /* Check if the Metal backend is available. */ bool is_available(); -/* Get the actively used memory in bytes. - * - * Note, this will not always match memory use reported by the system because - * it does not include cached memory buffers. - * */ -size_t get_active_memory(); - -/* Get the peak amount of used memory in bytes. - * - * The maximum memory used recorded from the beginning of the program - * execution or since the last call to reset_peak_memory. - * */ -size_t get_peak_memory(); - -/* Reset the peak memory to zero. - * */ -void reset_peak_memory(); - -/* Get the cache size in bytes. - * - * The cache includes memory not currently used that has not been returned - * to the system allocator. - * */ -size_t get_cache_memory(); - -/* Set the memory limit. - * The memory limit is a guideline for the maximum amount of memory to use - * during graph evaluation. If the memory limit is exceeded and there is no - * more RAM (including swap when available) allocations will result in an - * exception. - * - * When metal is available the memory limit defaults to 1.5 times the maximum - * recommended working set size reported by the device. - * - * Returns the previous memory limit. - * */ -size_t set_memory_limit(size_t limit); - -/* Get the current memory limit. */ -size_t get_memory_limit(); - -/* Set the free cache limit. - * If using more than the given limit, free memory will be reclaimed - * from the cache on the next allocation. To disable the cache, - * set the limit to 0. - * - * The cache limit defaults to the memory limit. - * - * Returns the previous cache limit. - * */ -size_t set_cache_limit(size_t limit); - -/* Clear the memory cache. */ -void clear_cache(); - -/* Set the wired size limit. - * - * Note, this function is only useful for macOS 15.0 or higher. - * - * The wired limit is the total size in bytes of memory that will be kept - * resident. The default value is ``0``. - * - * Setting a wired limit larger than system wired limit is an error. - * - * Returns the previous wired limit. 
- * */ -size_t set_wired_limit(size_t limit); - /** Capture a GPU trace, saving it to an absolute file `path` */ void start_capture(std::string path = ""); void stop_capture(); diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index ff561374d..8da147971 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_logsumexp_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, @@ -186,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, @@ -245,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + int, + int, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 67576f03f..6946ffb9e 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector& inputs, array& out) { void Contiguous::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous)) { + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { out.copy_shared_buffer(in); } else { copy_gpu(in, out, CopyType::General); diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index cc32797eb..6f5807543 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -14,93 +15,168 @@ namespace mlx::core { -void launch_qmm( - std::string name, - const std::vector& inputs, +namespace { + +inline array +ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } else { + return x; + } +} + +inline array ensure_row_contiguous_matrix( + const array& x, + metal::Device& d, + const Stream& s) { + auto stride_0 = x.strides()[x.ndim() - 2]; + auto stride_1 = x.strides()[x.ndim() - 1]; + if (stride_0 == x.shape(-1) && stride_1 == 1) { + return x; + } 
else { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } +} + +inline int get_qmv_batch_limit(int D, int O, metal::Device& d) { + auto arch = d.get_architecture(); + auto arch_size = arch.back(); + auto arch_gen = arch.substr(arch.size() - 3, 2); + if (arch_gen == "13" || arch_gen == "14") { + switch (arch_size) { + case 'd': + if (D <= 2048 && O <= 2048) { + return 32; + } else if (D <= 4096 && O <= 4096) { + return 18; + } else { + return 12; + } + default: + if (D <= 2048 && O <= 2048) { + return 14; + } else if (D <= 4096 && O <= 4096) { + return 10; + } else { + return 6; + } + } + } else { + switch (arch_size) { + case 'd': + if (D <= 2048 && O <= 2048) { + return 32; + } else if (D <= 4096 && O <= 4096) { + return 18; + } else { + return 12; + } + default: + if (D <= 2048 && O <= 2048) { + return 18; + } else if (D <= 4096 && O <= 4096) { + return 12; + } else { + return 10; + } + } + } +} + +inline int add_strides_and_shapes( + CommandEncoder& compute_encoder, + bool skip, + const array& x, + const array& w, + const array& scales, + const array& biases, + int offset) { + if (skip) { + return 0; + } + + // TODO: Collapse batch dimensions + + int x_batch_ndims = x.ndim() - 2; + int w_batch_ndims = w.ndim() - 2; + compute_encoder.set_bytes(x_batch_ndims, offset); + compute_encoder.set_vector_bytes(x.shape(), offset + 1); + compute_encoder.set_vector_bytes(x.strides(), offset + 2); + compute_encoder.set_bytes(w_batch_ndims, offset + 3); + compute_encoder.set_vector_bytes(w.shape(), offset + 4); + compute_encoder.set_vector_bytes(w.strides(), offset + 5); + compute_encoder.set_vector_bytes(scales.strides(), offset + 6); + compute_encoder.set_vector_bytes(biases.strides(), offset + 7); + + return 8; +} + +inline int add_gather_strides_and_shapes( + CommandEncoder& compute_encoder, + const array& lhs_indices, + const array& rhs_indices, + int offset) { + auto [shape, strides] = collapse_contiguous_dims( + lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); + int ndims = shape.size(); + + compute_encoder.set_bytes(ndims, offset); + compute_encoder.set_vector_bytes(shape, offset + 1); + compute_encoder.set_vector_bytes(strides[0], offset + 2); + compute_encoder.set_vector_bytes(strides[1], offset + 3); + + return 4; +} + +} // namespace + +void qmv_quad( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, int group_size, int bits, - int D, - int O, - int B, + int M, int N, - MTL::Size& group_dims, - MTL::Size& grid_dims, - bool batched, - bool matrix, - bool gather, - bool aligned, - bool quad, + int K, + metal::Device& d, const Stream& s) { - auto& x_pre = inputs[0]; - auto& w_pre = inputs[1]; - auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; + int B = out.size() / M / N; - // Ensure that the last two dims are row contiguous. - // TODO: Check if we really need this for x as well... 
- std::vector copies; - auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) { - auto stride_0 = arr.strides()[arr.ndim() - 2]; - auto stride_1 = arr.strides()[arr.ndim() - 1]; - if (stride_0 == arr.shape(-1) && stride_1 == 1) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto x = ensure_row_contiguous_last_dims(x_pre); - auto w = ensure_row_contiguous_last_dims(w_pre); - auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); + constexpr int quads_per_simd = 8; + constexpr int results_per_quadgroup = 8; + int bn = quads_per_simd * results_per_quadgroup; + int simdgroup_size = 32; + MTL::Size group_dims(simdgroup_size, 1, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); - int x_batch_ndims = x.ndim() - 2; - auto& x_shape = x.shape(); - auto& x_strides = x.strides(); - int w_batch_ndims = w.ndim() - 2; - auto& w_shape = w.shape(); - auto& w_strides = w.strides(); - auto& s_strides = scales.strides(); - auto& b_strides = biases.strides(); + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + "qmv_quad_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_d_", + K, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, "qmv_quad", type_string, group_size, bits, K, B > 1); - std::string aligned_n = (O % 32) == 0 ? "true" : "false"; - - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits; - if (quad) { - kname << "_d_" << D; - } - if (aligned) { - kname << "_alN_" << aligned_n; - } - if (!gather) { - kname << "_batch_" << batched; - } - - // Encode and dispatch kernel - std::string template_def; - if (quad) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, D, batched); - } else if (aligned && !gather) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n, batched); - } else if (!gather && !aligned) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, batched); - } else if (aligned && gather) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n); - } else { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits); - } - auto& d = metal::device(s.device); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); + auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -109,90 +185,87 @@ void launch_qmm( compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(D, 5); - compute_encoder.set_bytes(O, 6); - - int offset = 7; - if (matrix) { - compute_encoder.set_bytes(B, 7); - offset += 1; - } - - if (batched || gather) { - compute_encoder.set_bytes(x_batch_ndims, offset); - compute_encoder.set_vector_bytes(x_shape, offset + 1); - compute_encoder.set_vector_bytes(x_strides, offset + 2); - compute_encoder.set_bytes(w_batch_ndims, offset + 3); - compute_encoder.set_vector_bytes(w_shape, offset + 4); - 
compute_encoder.set_vector_bytes(w_strides, offset + 5); - compute_encoder.set_vector_bytes(s_strides, offset + 6); - compute_encoder.set_vector_bytes(b_strides, offset + 7); - } - if (gather) { - auto& lhs_indices = inputs[4]; - auto& rhs_indices = inputs[5]; - - // TODO: collapse batch dims - auto& batch_shape = lhs_indices.shape(); - int batch_ndims = batch_shape.size(); - auto& lhs_strides = lhs_indices.strides(); - auto& rhs_strides = rhs_indices.strides(); - - compute_encoder.set_bytes(batch_ndims, offset + 8); - compute_encoder.set_vector_bytes(batch_shape, offset + 9); - compute_encoder.set_input_array(lhs_indices, offset + 10); - compute_encoder.set_input_array(rhs_indices, offset + 11); - compute_encoder.set_vector_bytes(lhs_strides, offset + 12); - compute_encoder.set_vector_bytes(rhs_strides, offset + 13); - } + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); } -void qvm_split_k( - const std::vector& inputs, +void qmv( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, int group_size, int bits, - int D, - int O, - int B, + int M, int N, + int K, + metal::Device& d, const Stream& s) { - int split_k = D > 8192 ? 32 : 8; - int split_D = (D + split_k - 1) / split_k; - N *= split_k; + int B = out.size() / M / N; - int bo = 64; - int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(B, O / bo, N); + int bn = 8; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); - auto& x_pre = inputs[0]; - auto& w_pre = inputs[1]; - auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + bool fast = N % bn == 0 && K % 512 == 0; + concatenate( + kname, + fast ? "qmv_fast_" : "qmv_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1); - // Ensure that the last two dims are row contiguous. - // TODO: Check if we really need this for x as well... 
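// Illustrative aside, not part of the patch: qvm_split_k partitions the inner
// (K) dimension so several threadgroups reduce disjoint slices into an
// intermediate tensor that is summed afterwards. The arithmetic below mirrors
// the split selection in qvm_split_k, with a made-up K for illustration.
#include <cstdio>

int main() {
  int K = 16384;                                   // hypothetical inner dimension
  int split_k = K > 8192 ? 32 : 8;                 // more splits for large K
  int split_D = (K + split_k - 1) / split_k;       // slice handled per split: 512
  int final_block = K - (split_k - 1) * split_D;   // remainder for the last split: 512
  std::printf("split_k=%d split_D=%d final=%d\n", split_k, split_D, final_block);
}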
- std::vector copies; - auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) { - auto stride_0 = arr.strides()[arr.ndim() - 2]; - auto stride_1 = arr.strides()[arr.ndim() - 1]; - if (stride_0 == arr.shape(-1) && stride_1 == 1) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto x = ensure_row_contiguous_last_dims(x_pre); - auto w = ensure_row_contiguous_last_dims(w_pre); - auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void qvm_split_k( + const array& x, + const array& w, + const array& scales, + const array& biases, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int split_k = K > 8192 ? 32 : 8; + int split_D = (K + split_k - 1) / split_k; + int B = out.size() / M / N; + B *= split_k; + + int bn = 64; + int bk = 32; + MTL::Size group_dims = MTL::Size(bk, 2, 1); + MTL::Size grid_dims = MTL::Size(M, N / bn, B); int x_batch_ndims = x.ndim() - 2; auto x_shape = x.shape(); @@ -217,9 +290,7 @@ void qvm_split_k( s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1)); b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1)); - int final_block_size = D - (split_k - 1) * split_D; - - auto& d = metal::device(s.device); + int final_block_size = K - (split_k - 1) * split_D; auto temp_shape = out.shape(); temp_shape.insert(temp_shape.end() - 2, split_k); @@ -227,15 +298,24 @@ void qvm_split_k( intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_" - << bits << "_spk_" << split_k; + std::string type_string = get_type_string(x.dtype()); + std::string kname; + kname.reserve(64); + concatenate( + kname, + "qvm_split_k_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_spk_", + split_k); auto template_def = get_template_definition( - kname.str(), "qvm_split_k", type_string, group_size, bits, split_k); + kname, "qvm_split_k", type_string, group_size, bits, split_k); // Encode and dispatch kernel - auto kernel = get_quantized_kernel(d, kname.str(), template_def); + auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -245,7 +325,7 @@ void qvm_split_k( compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(intermediate, 4); compute_encoder.set_bytes(split_D, 5); - compute_encoder.set_bytes(O, 6); + compute_encoder.set_bytes(N, 6); compute_encoder.set_bytes(x_batch_ndims, 7); compute_encoder.set_vector_bytes(x_shape, 8); @@ -258,7 +338,6 @@ void 
qvm_split_k( compute_encoder.set_bytes(final_block_size, 15); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); int axis = intermediate.ndim() - 3; ReductionPlan plan( @@ -269,124 +348,589 @@ void qvm_split_k( intermediate, out, "sum", plan, {axis}, compute_encoder, d, s); } -void qmm_op( - const std::vector& inputs, +void qvm( + const array& x, + const array& w, + const array& scales, + const array& biases, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 64; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + "qvm_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, "qvm", type_string, group_size, bits, B > 1); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, bool transpose, int group_size, int bits, - bool gather, + int M, + int N, + int K, + metal::Device& d, const Stream& s) { - out.set_data(allocator::malloc(out.nbytes())); + int B = out.size() / M / N; - MTL::Size group_dims; - MTL::Size grid_dims; - - auto& x = inputs[0]; - auto& w = inputs[1]; - bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous); - - int D = x.shape(-1); - int O = out.shape(-1); - // For the unbatched W case, avoid `adjust_matrix_offsets` - // for a small performance gain. - int B = (batched || gather) ? x.shape(-2) : x.size() / D; - int N = (batched || gather) ? out.size() / B / O : 1; - - std::string name = gather ? "bs_" : ""; - bool matrix = false; - bool aligned = false; - bool quad = false; + int wm = 2; + int wn = 2; + int bm = 32; + int bn = 32; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); + std::string kname; + kname.reserve(64); + bool aligned = N % 32 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "qmm_t_" : "qmm_n_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + transpose ? (aligned ? "_alN_true" : "_alN_false") : "", + batched ? 
"_batch_1" : "_batch_0"); + std::string template_def; if (transpose) { - if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) { - name += "qmv_quad"; - constexpr int quads_per_simd = 8; - constexpr int results_per_quadgroup = 8; - int bo = quads_per_simd * results_per_quadgroup; - int simdgroup_size = 32; - group_dims = MTL::Size(simdgroup_size, 1, 1); - grid_dims = MTL::Size((O + bo - 1) / bo, B, N); - quad = true; - } else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) { - name += "qmv_fast"; - int bo = 8; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, O / bo, N); - } else if (B < 6) { - name += "qmv"; - int bo = 8; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); - } else { - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - group_dims = MTL::Size(32, wn, wm); - grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N); - name += "qmm_t"; - matrix = true; - aligned = true; - } + template_def = get_template_definition( + kname, "qmm_t", type_string, group_size, bits, aligned, batched); } else { - if (B < 4 && D >= 1024 && !gather) { - return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s); - } else if (B < 4) { - name += "qvm"; - int bo = 64; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, O / bo, N); - } else { - name += "qmm_n"; - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - group_dims = MTL::Size(32, wn, wm); - grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N); - matrix = true; - if ((O % bn) != 0) { - std::ostringstream msg; - msg << "[quantized_matmul] The output size should be divisible by " - << bn << " but received " << O << "."; - throw std::runtime_error(msg.str()); - } - } + template_def = get_template_definition( + kname, "qmm_n", type_string, group_size, bits, batched); } - launch_qmm( - name, - inputs, - out, + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + compute_encoder.set_bytes(M, 7); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int wm = 2; + int wn = 2; + int bm = 32; + int bn = 32; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); + + std::string kname; + kname.reserve(64); + bool aligned = N % 32 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "gather_qmm_t_" : "gather_qmm_n_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + transpose ? (aligned ? 
"_alN_true" : "_alN_false") : ""); + std::string template_def; + if (transpose) { + template_def = get_template_definition( + kname, "gather_qmm_t", type_string, group_size, bits, aligned); + } else { + template_def = get_template_definition( + kname, "gather_qmm_n", type_string, group_size, bits); + } + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + compute_encoder.set_bytes(M, 9); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 10 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmv( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 8; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + bool fast = N % bn == 0 && K % 512 == 0; + concatenate( + kname, + fast ? "gather_qmv_fast_" : "gather_qmv_", + type_string, + "_gs_", + group_size, + "_b_", + bits); + auto template_def = get_template_definition( + kname, + fast ? 
"gather_qmv_fast" : "gather_qmv", + type_string, + group_size, + bits); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 9 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qvm( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 64; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits); + auto template_def = get_template_definition( + kname, "gather_qvm", type_string, group_size, bits); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 9 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmm_rhs( + const array& x_, + const array& w_, + const array& scales_, + const array& biases_, + const array& indices_, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + // Start by normalizing the indices + array indices = ensure_row_contiguous(indices_, d, s); + + // Broadcast x with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of x broadcasted + // with rhs_indices. We need only broadcast x and copy it as if applying the + // lhs_indices. 
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); + } + + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); + }; + + // Normalize the input arrays + array x = broadcast_with_indices(x_); + array w = ensure_row_contiguous(w_, d, s); + array scales = ensure_row_contiguous(scales_, d, s); + array biases = ensure_row_contiguous(biases_, d, s); + + // TODO: Tune the block sizes + int bm = 16, bn = 32, bk = 32; + int wm = 1, wn = 2; + + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + // Make the kernel name + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm_", + bm, + "_bn_", + bn, + "_bk_", + bk, + "_wm_", + wm, + "_wn_", + wn); + + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + kname, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_gather_qmm_kernel( + d, + kname, + hash_name, + func_consts, + x, group_size, bits, - D, - O, - B, - N, - group_dims, - grid_dims, - batched, - matrix, - gather, - aligned, - quad, - s); + bm, + bn, + bk, + wm, + wn, + transpose); + compute_encoder.set_compute_pipeline_state(kernel); + + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(scales, 2); + compute_encoder.set_input_array(biases, 3); + compute_encoder.set_input_array(indices, 4); + compute_encoder.set_output_array(out, 5); + compute_encoder.set_bytes(M, 6); + compute_encoder.set_bytes(N, 7); + compute_encoder.set_bytes(K, 8); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 4); - qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream()); + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + // Make sure the last two dims of x and w, s, b are contiguous. This should + // be relaxed for x. + array x = ensure_row_contiguous_matrix(inputs[0], d, s); + array w = ensure_row_contiguous_matrix(inputs[1], d, s); + array scales = ensure_row_contiguous_matrix(inputs[2], d, s); + array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + + // Extract the matmul shapes + bool non_batched = w.ndim() == 2 && x.flags().row_contiguous; + int K = x.shape(-1); + int M = non_batched ? x.size() / K : x.shape(-2); + int N = out.shape(-1); + + int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; + + // It is a matrix matrix product. 
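// Illustrative aside, not part of the patch: the routing that follows chooses
// among qmm, qmv_quad, qmv, qvm and qvm_split_k from M, K, the transpose flag
// and a device-dependent vector limit. A standalone sketch of the same order;
// the route_qmm helper and its string return values are invented here for
// illustration only.
#include <string>

std::string route_qmm(int M, int K, int bits, bool transpose, int vector_limit) {
  auto is_pow2 = [](int b) { return b > 0 && (b & (b - 1)) == 0; };
  if (M >= vector_limit) return "qmm";                                  // matrix-matrix product
  if (transpose && (K == 128 || K == 64) && is_pow2(bits)) return "qmv_quad";
  if (transpose) return "qmv";                                          // plain qmv / qmv_fast
  if (K < 1024) return "qvm";
  return "qvm_split_k";                                                 // large K, more parallelism
}

int main() {
  // e.g. a single row with K = 4096 and a transposed 4-bit weight goes to qmv
  return route_qmm(/*M=*/1, /*K=*/4096, /*bits=*/4, /*transpose=*/true, /*vector_limit=*/14) == "qmv" ? 0 : 1;
}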
+ if (M >= vector_limit) { + qmm(x, + w, + scales, + biases, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + // It is a qmv with a small inner dimension so route to qmv_quad kernel + if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) { + qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Run of the mill qmv + if (transpose_) { + qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Run of the mill qvm + if (K < 1024) { + qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Qvm with large dimension so route to a split K kernel for more parallelism + qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; } void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 6); - qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream()); + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + array x = ensure_row_contiguous_matrix(inputs[0], d, s); + array w = ensure_row_contiguous_matrix(inputs[1], d, s); + array scales = ensure_row_contiguous_matrix(inputs[2], d, s); + array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + const array& lhs_indices = inputs[4]; + const array& rhs_indices = inputs[5]; + + int K = x.shape(-1); + int M = x.shape(-2); + int N = out.shape(-1); + int B = out.size() / M / N; + int E = w.size() / w.shape(-1) / w.shape(-2); + int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; + + // We are walking x in order and w is also in order so we can batch up the + // matmuls and reuse reading x and w. + // + // TODO: Tune 16 and 8 here a bit better. 
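// Illustrative aside, not part of the patch: the gating below only sends the
// gather to the dense rhs kernel when each gathered operand is a single row
// (M == 1), there are enough index entries to amortize the dense pass, and the
// right-hand indices are sorted so each weight matrix's rows are contiguous.
// The numbers are hypothetical, e.g. 64 tokens routed over 4 experts in a
// mixture-of-experts setting.
#include <cstdio>

int main() {
  int M = 1, B = 64, E = 4;   // rows per matmul, index entries, weight matrices
  bool right_sorted = true;
  bool use_rhs_kernel = (M == 1 && B >= 16 && right_sorted && B / E >= 8);
  std::printf("%s\n", use_rhs_kernel ? "gather_qmm_rhs" : "gather_qmm / gather_qmv");
}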
+ if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) { + gather_qmm_rhs( + x, + w, + scales, + biases, + rhs_indices, + out, + transpose_, + group_size_, + bits_, + x.size() / K, + N, + K, + d, + s); + return; + } + + // It is a matrix matrix product + if (M >= vector_limit) { + gather_qmm( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + if (transpose_) { + gather_qmv( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + gather_qvm( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + group_size_, + bits_, + M, + N, + K, + d, + s); } void fast::AffineQuantize::eval_gpu( @@ -398,27 +942,13 @@ void fast::AffineQuantize::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); - - std::vector copies; - auto ensure_row_contiguous = [&copies, &s](const array& arr) { - if (arr.flags().row_contiguous) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto w = ensure_row_contiguous(w_pre); - auto& compute_encoder = d.get_command_encoder(s.index); + + auto w = ensure_row_contiguous(w_pre, d, s); compute_encoder.set_input_array(w, 0); if (dequantize_) { - auto& scales_pre = inputs[1]; - auto& biases_pre = inputs[2]; - auto scales = ensure_row_contiguous(scales_pre); - auto biases = ensure_row_contiguous(biases_pre); + auto scales = ensure_row_contiguous(inputs[1], d, s); + auto biases = ensure_row_contiguous(inputs[2], d, s); compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(biases, 2); compute_encoder.set_output_array(out, 3); @@ -466,8 +996,6 @@ void fast::AffineQuantize::eval_gpu( MTL::Size grid_dims = use_2d ? 
get_2d_grid_dims(grid_shape, w.strides()) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 545f67e49..0a9e1b861 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -22,6 +22,7 @@ ResidencySet::ResidencySet(MTL::Device* d) { } throw std::runtime_error(msg.str()); } + wired_set_->requestResidency(); } } @@ -32,7 +33,6 @@ void ResidencySet::insert(MTL::Allocation* buf) { if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) { wired_set_->addAllocation(buf); wired_set_->commit(); - wired_set_->requestResidency(); } else { unwired_set_.insert(buf); } @@ -76,7 +76,6 @@ void ResidencySet::resize(size_t size) { } } wired_set_->commit(); - wired_set_->requestResidency(); } else if (current_size > size) { auto pool = new_scoped_memory_pool(); // Remove wired allocations until under capacity diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index f64d057ce..845962d01 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -138,6 +138,7 @@ void sdpa_vector( const array& v, array& out, float scale, + bool do_causal, const std::optional& mask) { // Set the kernel name std::string kname; @@ -162,14 +163,20 @@ void sdpa_vector( MTL::Size grid_dims(B, q.shape(2), 1); bool has_mask = mask.has_value(); + bool bool_mask = has_mask && (*mask).dtype() == bool_; + bool float_mask = has_mask && !bool_mask; bool query_transposed = !q.flags().row_contiguous; metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, {&query_transposed, MTL::DataType::DataTypeBool, 21}, + {&do_causal, MTL::DataType::DataTypeBool, 22}, + {&bool_mask, MTL::DataType::DataTypeBool, 23}, + {&float_mask, MTL::DataType::DataTypeBool, 24}, }; std::string hash_name = kname; - hash_name += has_mask ? "_mask" : "_nomask"; + hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -191,15 +198,15 @@ void sdpa_vector( compute_encoder.set_bytes(scale, 10); if (has_mask) { auto& m = *mask; - compute_encoder.set_input_array(m, 11); + compute_encoder.set_input_array(m, 11 + float_mask); auto nd = m.ndim(); int32_t kv_seq_stride = nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? 
m.strides()[nd - 3] : 0; - compute_encoder.set_bytes(kv_seq_stride, 12); - compute_encoder.set_bytes(q_seq_stride, 13); - compute_encoder.set_bytes(head_stride, 14); + compute_encoder.set_bytes(kv_seq_stride, 13); + compute_encoder.set_bytes(q_seq_stride, 14); + compute_encoder.set_bytes(head_stride, 15); } // Launch @@ -214,6 +221,7 @@ void sdpa_vector_2pass( const array& v, array& out, float scale, + bool do_causal, const std::optional& mask) { // Set the kernel name std::string kname; @@ -256,14 +264,20 @@ void sdpa_vector_2pass( d.add_temporary(maxs, s.index); bool has_mask = mask.has_value(); + bool bool_mask = has_mask && (*mask).dtype() == bool_; + bool float_mask = has_mask && !bool_mask; bool query_transposed = !q.flags().row_contiguous; metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, {&query_transposed, MTL::DataType::DataTypeBool, 21}, + {&do_causal, MTL::DataType::DataTypeBool, 22}, + {&bool_mask, MTL::DataType::DataTypeBool, 23}, + {&float_mask, MTL::DataType::DataTypeBool, 24}, }; std::string hash_name = kname; - hash_name += has_mask ? "_mask" : "_nomask"; + hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -287,15 +301,15 @@ void sdpa_vector_2pass( compute_encoder.set_bytes(scale, 12); if (has_mask) { auto& m = *mask; - compute_encoder.set_input_array(m, 13); + compute_encoder.set_input_array(m, 13 + float_mask); auto nd = m.ndim(); int32_t kv_seq_stride = nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? 
m.strides()[nd - 3] : 0; - compute_encoder.set_bytes(kv_seq_stride, 14); - compute_encoder.set_bytes(q_seq_stride, 15); - compute_encoder.set_bytes(head_stride, 16); + compute_encoder.set_bytes(kv_seq_stride, 15); + compute_encoder.set_bytes(q_seq_stride, 16); + compute_encoder.set_bytes(head_stride, 17); } // Launch @@ -401,12 +415,13 @@ void ScaledDotProductAttention::eval_gpu( // We route to the 2 pass fused attention if // - The device is large and the sequence length long // - The sequence length is even longer and we have gqa + bool do_causal = do_causal_ && q.shape(2) > 1; char devc = d.get_architecture().back(); if ((devc == 'd' && k.shape(2) >= 1024) || (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { - sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask); + sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask); } else { - sdpa_vector(s, d, q, k, v, o, scale_, mask); + sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask); } } diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index c7e0087b7..b1800fea9 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { case Scan::Min: reduce_type = "min"; break; + case Scan::LogAddExp: + reduce_type = "logaddexp"; + break; } kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out); auto kernel = get_scan_kernel( diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index b089188b8..224721a50 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous auto set_output = [&s, &out](const array& x) { - bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; - if (no_copy && x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); - } - if (no_copy) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index cc56bab32..079d15f17 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "mlx/array.h" #include "mlx/backend/metal/device.h" #include "mlx/primitives.h" @@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label( std::string get_primitive_string(Primitive* primitive); +template +constexpr bool is_numeric_except_char = std::is_arithmetic_v && + !std::is_same_v && !std::is_same_v && + !std::is_same_v && !std::is_same_v; + template void concatenate(std::string& acc, T first) { - acc += first; + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } } template void concatenate(std::string& acc, T first, Args... 
args) { - acc += first; + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } concatenate(acc, args...); } diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 2f1ae566f..84372b096 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -82,6 +82,7 @@ NO_CPU(LogicalNot) NO_CPU(LogicalAnd) NO_CPU(LogicalOr) NO_CPU(LogAddExp) +NO_CPU(LogSumExp) NO_CPU_MULTI(LUF) NO_CPU(Matmul) NO_CPU(Maximum) diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_metal/allocator.cpp index 0429ea53a..a8b260b6b 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_metal/allocator.cpp @@ -1,14 +1,72 @@ // Copyright © 2023 Apple Inc. +#include +#include + #include "mlx/allocator.h" -namespace mlx::core::allocator { +#ifdef __APPLE__ +#include "mlx/backend/no_metal/apple_memory.h" +#elif defined(__linux__) +#include "mlx/backend/no_metal/linux_memory.h" +#else +size_t get_memory_size() { + return 0; +} +#endif -Allocator& allocator() { +namespace mlx::core { + +namespace allocator { + +class CommonAllocator : public Allocator { + /** A general CPU allocator. */ + public: + virtual Buffer malloc(size_t size) override; + virtual void free(Buffer buffer) override; + virtual size_t size(Buffer buffer) const override; + size_t get_active_memory() const { + return active_memory_; + }; + size_t get_peak_memory() const { + return peak_memory_; + }; + void reset_peak_memory() { + std::unique_lock lk(mutex_); + peak_memory_ = 0; + }; + size_t get_memory_limit() { + return memory_limit_; + } + size_t set_memory_limit(size_t limit) { + std::unique_lock lk(mutex_); + std::swap(memory_limit_, limit); + return limit; + } + + private: + size_t memory_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + std::mutex mutex_; + CommonAllocator() : memory_limit_(0.8 * get_memory_size()) { + if (memory_limit_ == 0) { + memory_limit_ = 1UL << 33; + } + }; + + friend CommonAllocator& common_allocator(); +}; + +CommonAllocator& common_allocator() { static CommonAllocator allocator_; return allocator_; } +Allocator& allocator() { + return common_allocator(); +} + void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; @@ -16,4 +74,59 @@ void* Buffer::raw_ptr() { return static_cast(ptr_) + 1; } -} // namespace mlx::core::allocator +Buffer CommonAllocator::malloc(size_t size) { + void* ptr = std::malloc(size + sizeof(size_t)); + if (ptr != nullptr) { + *static_cast(ptr) = size; + } + std::unique_lock lk(mutex_); + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{ptr}; +} + +void CommonAllocator::free(Buffer buffer) { + auto sz = size(buffer); + std::free(buffer.ptr()); + std::unique_lock lk(mutex_); + active_memory_ -= sz; +} + +size_t CommonAllocator::size(Buffer buffer) const { + if (buffer.ptr() == nullptr) { + return 0; + } + return *static_cast(buffer.ptr()); +} + +} // namespace allocator + +size_t get_active_memory() { + return allocator::common_allocator().get_active_memory(); +} +size_t get_peak_memory() { + return allocator::common_allocator().get_peak_memory(); +} +void reset_peak_memory() { + return allocator::common_allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return allocator::common_allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return allocator::common_allocator().get_memory_limit(); +} + +// No-ops for common allocator +size_t get_cache_memory() { + return 0; +} 
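// Illustrative aside, not part of the patch: the CPU allocator prepends a
// size_t header to every allocation so that size() and the active/peak memory
// counters need no separate bookkeeping table. A minimal standalone version of
// that header trick, independent of the Buffer/Allocator types:
#include <cstddef>
#include <cstdio>
#include <cstdlib>

void* sized_malloc(size_t size) {
  void* ptr = std::malloc(size + sizeof(size_t));
  if (ptr) *static_cast<size_t*>(ptr) = size;       // stash the size up front
  return ptr;
}
size_t sized_size(void* ptr) { return ptr ? *static_cast<size_t*>(ptr) : 0; }
void* sized_user_ptr(void* ptr) { return static_cast<size_t*>(ptr) + 1; }
void sized_free(void* ptr) { std::free(ptr); }

int main() {
  void* buf = sized_malloc(1024);
  std::printf("%zu\n", sized_size(buf));  // 1024
  (void)sized_user_ptr(buf);
  sized_free(buf);
}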
+size_t set_cache_limit(size_t) { + return 0; +} +size_t set_wired_limit(size_t) { + return 0; +} +void clear_cache() {} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/apple_memory.h b/mlx/backend/no_metal/apple_memory.h new file mode 100644 index 000000000..7fdc53014 --- /dev/null +++ b/mlx/backend/no_metal/apple_memory.h @@ -0,0 +1,16 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + size_t memsize = 0; + size_t length = sizeof(memsize); + sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + return memsize; +} + +} // namespace diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_metal/event.cpp index 692be20d2..6dde047ab 100644 --- a/mlx/backend/no_metal/event.cpp +++ b/mlx/backend/no_metal/event.cpp @@ -28,21 +28,19 @@ void Event::wait() { ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); } -void Event::signal() { - auto ec = static_cast(event_.get()); - { - std::lock_guard lk(ec->mtx); - ec->value = value(); - } - ec->cv.notify_all(); -} - void Event::wait(Stream stream) { scheduler::enqueue(stream, [*this]() mutable { wait(); }); } void Event::signal(Stream stream) { - scheduler::enqueue(stream, [*this]() mutable { signal(); }); + scheduler::enqueue(stream, [*this]() mutable { + auto ec = static_cast(event_.get()); + { + std::lock_guard lk(ec->mtx); + ec->value = value(); + } + ec->cv.notify_all(); + }); } bool Event::is_signaled() const { diff --git a/mlx/backend/no_metal/linux_memory.h b/mlx/backend/no_metal/linux_memory.h new file mode 100644 index 000000000..f909edcd7 --- /dev/null +++ b/mlx/backend/no_metal/linux_memory.h @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + struct sysinfo info; + + if (sysinfo(&info) != 0) { + return 0; + } + + size_t total_ram = info.totalram; + total_ram *= info.mem_unit; + + return total_ram; +} + +} // namespace diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 03c68c734..ef9af8800 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -31,33 +31,8 @@ void synchronize(Stream) { "[metal::synchronize] Cannot synchronize GPU without metal backend"); } -// No-ops when Metal is not available. 
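// Illustrative aside, not part of the patch: the common allocator derives its
// default memory limit from total system RAM (hw.memsize via sysctlbyname on
// macOS, sysinfo on Linux) and falls back to 8 GiB when the size cannot be
// determined. The helper below mirrors that default, with a made-up RAM size.
#include <cstddef>
#include <cstdio>

size_t default_memory_limit(size_t total_ram_bytes) {
  size_t limit = static_cast<size_t>(0.8 * total_ram_bytes);
  return limit == 0 ? (1UL << 33) : limit;  // 8 GiB fallback
}

int main() {
  std::printf("%zu\n", default_memory_limit(16UL << 30));  // hypothetical 16 GiB machine
}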
-size_t get_active_memory() { - return 0; -} -size_t get_peak_memory() { - return 0; -} -void reset_peak_memory() {} -size_t get_cache_memory() { - return 0; -} -size_t set_memory_limit(size_t) { - return 0; -} -size_t get_memory_limit() { - return 0; -} -size_t set_cache_limit(size_t) { - return 0; -} -size_t set_wired_limit(size_t) { - return 0; -} - void start_capture(std::string) {} void stop_capture() {} -void clear_cache() {} const std::unordered_map>& device_info() { diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 6e37a1d2b..6826c97f6 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -82,6 +82,7 @@ NO_GPU(LogicalNot) NO_GPU(LogicalAnd) NO_GPU(LogicalOr) NO_GPU(LogAddExp) +NO_GPU(LogSumExp) NO_GPU_MULTI(LUF) NO_GPU(Matmul) NO_GPU(Maximum) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index 33f1d1320..cc01e6090 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_sum(input, output, stream); } +void all_max(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_max(input, output, stream); +} + +void all_min(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_min(input, output, stream); +} + void all_gather(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_gather(input, output, stream); } @@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } + + void all_max(const array&, array&, Stream) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } + + void all_min(const array&, array&, Stream) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } }; } // namespace detail diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index 7c06068b2..8b0327131 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -21,6 +21,8 @@ class GroupImpl { virtual void all_gather(const array& input, array& output, Stream stream) = 0; virtual void send(const array& input, int dst, Stream stream) = 0; virtual void recv(array& out, int src, Stream stream) = 0; + virtual void all_max(const array& input, array& output, Stream stream) = 0; + virtual void all_min(const array& input, array& output, Stream stream) = 0; }; /* Perform an all reduce sum operation */ @@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream); /** Recv an array from the src rank */ void recv(Group group, array& out, int src, Stream stream); +/** Max reduction */ +void all_max(Group group, const array& input, array& output, Stream stream); + +/** Min reduction */ +void all_min(Group group, const array& input, array& output, Stream stream); + } // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/mpi/CMakeLists.txt b/mlx/distributed/mpi/CMakeLists.txt index 7063a101f..842f70b55 100644 --- a/mlx/distributed/mpi/CMakeLists.txt +++ b/mlx/distributed/mpi/CMakeLists.txt @@ -1,4 +1,4 @@ -if(MPI_FOUND AND MLX_BUILD_CPU) +if(MLX_BUILD_CPU) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) else() target_sources(mlx PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp) diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index b9136f701..e80a1759f 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -1,12 +1,13 @@ // Copyright © 2024 Apple Inc. #include -#include +#include #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" +#include "mlx/distributed/mpi/mpi_declarations.h" #define LOAD_SYMBOL(symbol, variable) \ { \ @@ -18,6 +19,12 @@ } \ } +#ifdef __APPLE__ +static constexpr const char* libmpi_name = "libmpi.dylib"; +#else +static constexpr const char* libmpi_name = "libmpi.so"; +#endif + namespace mlx::core::distributed::mpi { using GroupImpl = mlx::core::distributed::detail::GroupImpl; @@ -43,15 +50,69 @@ void simple_sum( template void simple_sum(void*, void*, int*, MPI_Datatype*); template void simple_sum(void*, void*, int*, MPI_Datatype*); +template +void simple_max( + void* input, + void* accumulator, + int* len, + MPI_Datatype* datatype) { + T* in = (T*)input; + T* acc = (T*)accumulator; + int N = *len; + + while (N-- > 0) { + *acc = std::max(*acc, *in); + acc++; + in++; + } +} +template void simple_max(void*, void*, int*, MPI_Datatype*); +template void simple_max(void*, void*, int*, MPI_Datatype*); +template void simple_max(void*, void*, int*, MPI_Datatype*); + +template +void simple_min( + void* input, + void* accumulator, + int* len, + MPI_Datatype* datatype) { + T* in = (T*)input; + T* acc = (T*)accumulator; + int N = *len; + + while (N-- > 0) { + *acc = std::min(*acc, *in); + acc++; + in++; + } +} +template void simple_min(void*, void*, int*, MPI_Datatype*); +template void simple_min(void*, void*, int*, MPI_Datatype*); +template void simple_min(void*, void*, int*, MPI_Datatype*); + struct MPIWrapper { MPIWrapper() { initialized_ = false; - libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL); + libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL); if (libmpi_handle_ == nullptr) { return; } + // Check library version and warn if it isn't Open MPI + int (*get_version)(char*, int*); + LOAD_SYMBOL(MPI_Get_library_version, get_version); + char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING]; + int version_length = 0; + get_version(version_ptr, &version_length); + std::string_view version(version_ptr, version_length); + if (version.find("Open MPI") == std::string::npos) { + std::cerr << "[mpi] MPI found but it does not appear to be Open MPI." 
+ << "MLX requires Open MPI but this is " << version << std::endl; + libmpi_handle_ = nullptr; + return; + } + // API LOAD_SYMBOL(MPI_Init, init); LOAD_SYMBOL(MPI_Finalize, finalize); @@ -72,6 +133,8 @@ struct MPIWrapper { // Ops LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_); + LOAD_SYMBOL(ompi_mpi_op_max, op_max_); + LOAD_SYMBOL(ompi_mpi_op_min, op_min_); // Datatypes LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_); @@ -106,9 +169,15 @@ struct MPIWrapper { mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_); mpi_type_commit(&mpi_bfloat16_); - // Custom sum ops + // Custom reduction ops mpi_op_create(&simple_sum, 1, &op_sum_f16_); mpi_op_create(&simple_sum, 1, &op_sum_bf16_); + mpi_op_create(&simple_max, 1, &op_max_f16_); + mpi_op_create(&simple_max, 1, &op_max_bf16_); + mpi_op_create(&simple_max, 1, &op_max_c64_); + mpi_op_create(&simple_min, 1, &op_min_f16_); + mpi_op_create(&simple_min, 1, &op_min_bf16_); + mpi_op_create(&simple_min, 1, &op_min_c64_); initialized_ = true; } @@ -170,6 +239,32 @@ struct MPIWrapper { } } + MPI_Op op_max(const array& arr) { + switch (arr.dtype()) { + case float16: + return op_max_f16_; + case bfloat16: + return op_max_bf16_; + case complex64: + return op_max_c64_; + default: + return op_max_; + } + } + + MPI_Op op_min(const array& arr) { + switch (arr.dtype()) { + case float16: + return op_min_f16_; + case bfloat16: + return op_min_bf16_; + case complex64: + return op_min_c64_; + default: + return op_min_; + } + } + void* libmpi_handle_; // API @@ -198,6 +293,14 @@ struct MPIWrapper { MPI_Op op_sum_; MPI_Op op_sum_f16_; MPI_Op op_sum_bf16_; + MPI_Op op_max_; + MPI_Op op_max_f16_; + MPI_Op op_max_bf16_; + MPI_Op op_max_c64_; + MPI_Op op_min_; + MPI_Op op_min_f16_; + MPI_Op op_min_bf16_; + MPI_Op op_min_c64_; // Datatypes MPI_Datatype mpi_bool_; @@ -285,6 +388,36 @@ class MPIGroup : public GroupImpl { comm_); } + void all_max(const array& input, array& output, Stream stream) override { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch( + mpi().all_reduce, + (input.data() == output.data()) ? MPI_IN_PLACE + : input.data(), + output.data(), + input.size(), + mpi().datatype(input), + mpi().op_max(input), + comm_); + } + + void all_min(const array& input, array& output, Stream stream) override { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch( + mpi().all_reduce, + (input.data() == output.data()) ? MPI_IN_PLACE + : input.data(), + output.data(), + input.size(), + mpi().datatype(input), + mpi().op_min(input), + comm_); + } + void all_gather(const array& input, array& output, Stream stream) override { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); diff --git a/mlx/distributed/mpi/mpi_declarations.h b/mlx/distributed/mpi/mpi_declarations.h new file mode 100644 index 000000000..99c1a9cbb --- /dev/null +++ b/mlx/distributed/mpi/mpi_declarations.h @@ -0,0 +1,28 @@ +// Copyright © 2024 Apple Inc. + +// Constants + +#define MPI_SUCCESS 0 +#define MPI_ANY_SOURCE -1 +#define MPI_ANY_TAG -1 +#define MPI_IN_PLACE ((void*)1) +#define MPI_MAX_LIBRARY_VERSION_STRING 256 + +// Define all the types that we use so that we don't include which +// causes linker errors on some platforms. +// +// NOTE: We define everything for openmpi. 
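The MPIWrapper changes register simple_max and simple_min through MPI_Op_create because Open MPI ships no built-in reduction operators for float16, bfloat16, or complex64. Below is a minimal sketch of that registration pattern against the standard MPI C API, using plain float so it builds anywhere Open MPI is installed; the program layout and values are illustrative only, and unlike mpi.cpp (which dlopens libmpi and loads symbols at runtime) it links against MPI directly for brevity.

// Sketch: register a custom element-wise max reduction and use it in an
// in-place MPI_Allreduce, the same pattern mpi.cpp uses for f16/bf16.
#include <mpi.h>
#include <algorithm>
#include <cstdio>

static void elementwise_max(void* in, void* inout, int* len, MPI_Datatype*) {
  auto* a = static_cast<float*>(in);
  auto* b = static_cast<float*>(inout);
  for (int i = 0; i < *len; ++i) {
    b[i] = std::max(b[i], a[i]);
  }
}

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  MPI_Op max_op;
  MPI_Op_create(&elementwise_max, /*commute=*/1, &max_op);

  float value = static_cast<float>(rank);
  // In-place all-reduce: every rank ends up with the largest rank index.
  MPI_Allreduce(MPI_IN_PLACE, &value, 1, MPI_FLOAT, max_op, MPI_COMM_WORLD);
  std::printf("rank %d sees max %f\n", rank, value);

  MPI_Op_free(&max_op);
  MPI_Finalize();
  return 0;
}

MPI_IN_PLACE is used the same way the new MPIGroup::all_max does when the input and output arrays share a buffer.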
+ +typedef void* MPI_Comm; +typedef void* MPI_Datatype; +typedef void* MPI_Op; + +typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*); + +typedef struct ompi_status_public_t { + int MPI_SOURCE; + int MPI_TAG; + int MPI_ERROR; + int _cancelled; + size_t _ucount; +} MPI_Status; diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 865911ac6..0a5114805 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -36,6 +36,40 @@ array all_sum( {x}); } +array all_max( + const array& x, + std::optional group_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto group = to_group(group_); + + if (group.size() == 1) { + return x; + } + return array( + x.shape(), + x.dtype(), + std::make_shared( + to_stream(s, Device::cpu), group, AllReduce::Max), + {x}); +} + +array all_min( + const array& x, + std::optional group_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto group = to_group(group_); + + if (group.size() == 1) { + return x; + } + return array( + x.shape(), + x.dtype(), + std::make_shared( + to_stream(s, Device::cpu), group, AllReduce::Min), + {x}); +} + array all_gather( const array& x, std::optional group_ /* = std::nullopt */, diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h index 9430106b1..edd1fc9f4 100644 --- a/mlx/distributed/ops.h +++ b/mlx/distributed/ops.h @@ -38,4 +38,14 @@ array recv_like( std::optional group = std::nullopt, StreamOrDevice s = {}); +array all_max( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array all_min( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp index de28788b3..576424cdd 100644 --- a/mlx/distributed/primitives.cpp +++ b/mlx/distributed/primitives.cpp @@ -15,8 +15,14 @@ std::pair, std::vector> AllReduce::vmap( switch (reduce_type_) { case Sum: return {{all_sum(inputs[0], group(), stream())}, axes}; + case Max: + return {{all_max(inputs[0], group(), stream())}, axes}; + case Min: + return {{all_min(inputs[0], group(), stream())}, axes}; default: - throw std::runtime_error("Only all reduce sum is supported for now"); + + throw std::runtime_error( + "Only all reduce sum, max and min are supported for now"); } } @@ -27,8 +33,13 @@ std::vector AllReduce::jvp( switch (reduce_type_) { case Sum: return {all_sum(tangents[0], group(), stream())}; + case Max: + return {all_max(tangents[0], group(), stream())}; + case Min: + return {all_min(tangents[0], group(), stream())}; default: - throw std::runtime_error("Only all reduce sum is supported for now"); + throw std::runtime_error( + "Only all reduce sum, max and min are supported for now"); } } diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index d8fe53051..b31274e23 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -502,15 +503,38 @@ std::vector make_connections( return sockets; } +template +struct SumOp { + void operator()(const T* input, T* output, size_t N) { + while (N-- > 0) { + *output += *input; + input++; + output++; + } + } +}; template -void sum_inplace(const T* input, T* output, size_t N) { - while (N-- > 0) { - *output += *input; - input++; - output++; +struct MaxOp { + void operator()(const T* input, T* output, size_t N) { + while (N-- > 0) { + *output = std::max(*output, *input); + input++; + output++; + } } -} 
+}; + +template +struct MinOp { + void operator()(const T* input, T* output, size_t N) { + while (N-- > 0) { + *output = std::min(*output, *input); + input++; + output++; + } + } +}; } // namespace @@ -603,8 +627,19 @@ class RingGroup : public GroupImpl { return size_; } - void all_sum(const array& input_, array& output, Stream stream) override { - SWITCH_TYPE(output, all_sum(input_, output, stream)); + void all_sum(const array& input, array& output, Stream stream) override { + SWITCH_TYPE( + output, all_reduce>(input, output, stream, SumOp())); + } + + void all_max(const array& input, array& output, Stream stream) override { + SWITCH_TYPE( + output, all_reduce>(input, output, stream, MaxOp())); + } + + void all_min(const array& input, array& output, Stream stream) override { + SWITCH_TYPE( + output, all_reduce>(input, output, stream, MinOp())); } std::shared_ptr split(int color, int key = -1) override { @@ -612,7 +647,39 @@ class RingGroup : public GroupImpl { } void all_gather(const array& input, array& output, Stream stream) override { - throw std::runtime_error("[ring] All gather not supported."); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch([input_ptr = input.data(), + nbytes = input.nbytes(), + output_ptr = output.data(), + this]() { + constexpr size_t min_send_size = 262144; + size_t n_gathers = std::max( + std::min( + sockets_right_.size() + sockets_left_.size(), + nbytes / min_send_size), + size_t(1)); + size_t bytes_per_gather = ceildiv(nbytes, n_gathers); + std::vector> all_gathers; + for (int i = 0; i < n_gathers; i++) { + auto offset = i * bytes_per_gather; + all_gathers.emplace_back(pool_.enqueue(std::bind( + &RingGroup::all_gather_impl, + this, + input_ptr + offset, + output_ptr + offset, + nbytes, + offset + bytes_per_gather > nbytes ? nbytes - offset + : bytes_per_gather, + sockets_right_[i / 2], + sockets_left_[i / 2], + (i % 2) ? -1 : 1))); + } + for (auto& f : all_gathers) { + f.wait(); + } + }); } void send(const array& input, int dst, Stream stream) override { @@ -642,9 +709,8 @@ class RingGroup : public GroupImpl { encoder.dispatch( [out_ptr = out.data(), nbytes = out.nbytes(), src, this]() { // NOTE: We 'll check the sockets with the opposite order of send so - // that - // they work even with 2 nodes where left and right is the same - // neighbor. + // that they work even with 2 nodes where left and right is the same + // neighbor. int right = (rank_ + 1) % size_; int left = (rank_ + size_ - 1) % size_; if (src == left) { @@ -662,13 +728,17 @@ class RingGroup : public GroupImpl { } private: - template - void all_sum(const array& input, array& output, Stream stream) { + template + void all_reduce( + const array& input, + array& output, + Stream stream, + ReduceOp reduce_op) { auto in_ptr = input.data(); auto out_ptr = output.data(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(output); - encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() { + encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { // If the input data cannot be split into size_ segments then copy it and // all reduce a local buffer prefilled with 0s. 
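The ring all_reduce body continues below; before it, here is a hedged usage sketch of the public all_max / all_min ops these backends implement (declared in mlx/distributed/ops.h earlier in this patch). It assumes the usual mlx::core::distributed::init() entry point and a multi-process launcher such as mpirun or mlx.launch; with a single process both ops simply return their input.

// Sketch: each rank contributes its rank index; all ranks then agree on the
// maximum and minimum across the group.
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto group = mx::distributed::init();
  auto x = mx::array(static_cast<float>(group.rank()));
  auto hi = mx::distributed::all_max(x);  // same value on every rank
  auto lo = mx::distributed::all_min(x);
  std::cout << "rank " << group.rank() << ": max=" << hi << " min=" << lo
            << std::endl;
  return 0;
}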
size_t nbytes = size * sizeof(T); @@ -685,13 +755,14 @@ class RingGroup : public GroupImpl { char buffer[1024]; std::memset(buffer, 0, size_ * sizeof(T)); std::memcpy(buffer, in_ptr, nbytes); - all_sum_impl( + all_reduce_impl( reinterpret_cast(buffers_.data()), reinterpret_cast(buffer), size_, sockets_right_[0], sockets_left_[0], - -1); + -1, + reduce_op); std::memcpy(out_ptr, buffer, nbytes); return; } @@ -708,13 +779,13 @@ class RingGroup : public GroupImpl { std::min( sockets_right_.size() + sockets_left_.size(), nbytes / (size_ * min_send_size)), - 1UL); + size_t(1)); size_t step = ceildiv(size, n_reduces); std::vector> all_sums; for (int i = 0; i < n_reduces; i++) { all_sums.emplace_back(pool_.enqueue(std::bind( - &RingGroup::all_sum_impl, + &RingGroup::all_reduce_impl, this, reinterpret_cast( buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS), @@ -722,7 +793,8 @@ class RingGroup : public GroupImpl { std::min(size, (i + 1) * step) - i * step, sockets_right_[i / 2], sockets_left_[i / 2], - (i % 2) ? -1 : 1))); + (i % 2) ? -1 : 1, + reduce_op))); } for (auto& f : all_sums) { f.wait(); @@ -730,14 +802,15 @@ class RingGroup : public GroupImpl { }); } - template - void all_sum_impl( + template + void all_reduce_impl( T* buffer, T* data, size_t data_size, int socket_right, int socket_left, - int direction) { + int direction, + ReduceOp reduce_op) { // Choose which socket we send to and recv from int socket_send = (direction < 0) ? socket_right : socket_left; int socket_recv = (direction < 0) ? socket_left : socket_right; @@ -745,8 +818,8 @@ class RingGroup : public GroupImpl { // We split the data into `size_` segments of size `segment_size` and each // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets. size_t segment_size = ceildiv(data_size, size_); - size_t BUFFER_SIZE = - std::max(32768UL, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); + size_t BUFFER_SIZE = std::max( + size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); size_t n_packets = ceildiv(segment_size, BUFFER_SIZE); // Initial segments @@ -814,7 +887,7 @@ class RingGroup : public GroupImpl { sends[b].wait(); recvs[b].wait(); if (2 * j < send_plan.size()) { - sum_inplace( + reduce_op( recv_buffers[j % ALL_SUM_BUFFERS], data + recv_plan[j].first, recv_plan[j].second - recv_plan[j].first); @@ -827,9 +900,46 @@ class RingGroup : public GroupImpl { recvs[b].wait(); } + void all_gather_impl( + const char* input, + char* output, + size_t input_size, + size_t data_size, + int socket_right, + int socket_left, + int direction) { + // Choose which socket we send to and recv from + int socket_send = (direction < 0) ? socket_right : socket_left; + int socket_recv = (direction < 0) ? socket_left : socket_right; + + // Initial segments + int send_segment = rank_; + int recv_segment = (rank_ + direction + size_) % size_; + + // Copy our own segment in the output + std::memcpy(output + rank_ * input_size, input, data_size); + + // Simple send/recv all gather. Possible performance improvement by + // splitting to multiple chunks and allowing send/recv to run a bit ahead. + // See all_sum_impl for an example. 
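The comment above describes the all-gather as a simple send/recv ring. A single-process sketch of just the segment-index bookkeeping (no sockets, ranks simulated in one loop, constants illustrative) shows why size - 1 steps are enough for every rank to hold every segment; the loop that follows implements the same recurrence over real sockets.

// Sketch: simulate the ring all-gather rotation of segment indices.
#include <cstdio>
#include <vector>

int main() {
  const int size = 4;      // number of ranks in the ring
  const int direction = 1; // one direction only; all_gather launches both
  // held[r][s] is true once rank r has segment s in its output buffer.
  std::vector<std::vector<bool>> held(size, std::vector<bool>(size, false));
  std::vector<int> send_segment(size), recv_segment(size);
  for (int r = 0; r < size; ++r) {
    held[r][r] = true;                                // local memcpy
    send_segment[r] = r;                              // own segment goes first
    recv_segment[r] = (r + direction + size) % size;  // first incoming segment
  }
  for (int step = 0; step < size - 1; ++step) {
    for (int r = 0; r < size; ++r) {
      // The neighbor we receive from is forwarding exactly this segment.
      held[r][recv_segment[r]] = true;
      send_segment[r] = (send_segment[r] + size + direction) % size;
      recv_segment[r] = (recv_segment[r] + size + direction) % size;
    }
  }
  for (int r = 0; r < size; ++r) {
    int count = 0;
    for (bool h : held[r]) count += h;
    std::printf("rank %d holds %d of %d segments\n", r, count, size);
  }
  return 0;
}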
+ for (int i = 0; i < size_ - 1; i++) { + auto sent = comm_.send( + socket_send, output + send_segment * input_size, data_size); + auto recvd = comm_.recv( + socket_recv, output + recv_segment * input_size, data_size); + + send_segment = (send_segment + size_ + direction) % size_; + recv_segment = (recv_segment + size_ + direction) % size_; + + sent.wait(); + recvd.wait(); + } + } + void send(const std::vector& sockets, const char* data, size_t data_size) { - size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size())); + size_t segment_size = + std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> sends; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) { @@ -846,7 +956,8 @@ class RingGroup : public GroupImpl { } void recv(const std::vector& sockets, char* data, size_t data_size) { - size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size())); + size_t segment_size = + std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> recvs; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) { diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp new file mode 100644 index 000000000..a4448536d --- /dev/null +++ b/mlx/dtype_utils.cpp @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +const char* dtype_to_string(Dtype arg) { + if (arg == bool_) { + return "bool"; + } +#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ + if (DTYPE == arg) { \ + return #DTYPE; \ + } + MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) +#undef SPECIALIZE_DtypeToString + return "(unknown)"; +} + +} // namespace mlx::core diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h new file mode 100644 index 000000000..55de890f2 --- /dev/null +++ b/mlx/dtype_utils.h @@ -0,0 +1,207 @@ +// Copyright © 2025 Apple Inc. +// Copyright © Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in +// https://github.com/pytorch/executorch/blob/main/LICENSE +// +// Forked from +// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h + +#pragma once + +#include "mlx/dtype.h" + +#include + +namespace mlx::core { + +// Return string representation of dtype. +const char* dtype_to_string(Dtype arg); + +// Macros that iterate across different subsets of Dtypes. +// +// For all of these macros, the final `_` parameter is the name of another macro +// that takes two parameters: the name of a C type, and the name of the +// corresponding Dtype enumerator. +// +// Note that these macros should use fully-qualified namespaces (starting with +// `::`) to ensure that they can be called safely in any arbitrary namespace. +#define MLX_FORALL_INT_TYPES(_) \ + _(uint8_t, uint8) \ + _(uint16_t, uint16) \ + _(uint32_t, uint32) \ + _(uint64_t, uint64) \ + _(int8_t, int8) \ + _(int16_t, int16) \ + _(int32_t, int32) \ + _(int64_t, int64) + +#define MLX_FORALL_FLOAT_TYPES(_) \ + _(float16_t, float16) \ + _(float, float32) \ + _(double, float64) \ + _(bfloat16_t, bfloat16) + +// Calls the provided macro on every Dtype, providing the C type and the +// Dtype name to each call. +// +// @param _ A macro that takes two parameters: the name of a C type, and the +// name of the corresponding Dtype enumerator. +#define MLX_FORALL_DTYPES(_) \ + MLX_FORALL_INT_TYPES(_) \ + MLX_FORALL_FLOAT_TYPES(_) \ + _(bool, bool_) \ + _(complex64_t, complex64) + +// Maps Dtypes to C++ types. 
+template +struct DtypeToCppType; + +#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \ + template <> \ + struct DtypeToCppType { \ + using type = CPP_TYPE; \ + }; + +MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType) + +#undef SPECIALIZE_DtypeToCppType + +// Maps C++ types to Dtypes. +template +struct CppTypeToDtype; + +#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \ + template <> \ + struct CppTypeToDtype \ + : std::integral_constant {}; + +MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype) + +#undef SPECIALIZE_CppTypeToDtype + +// Helper macros for switch case macros (see below) +// +// These macros are not meant to be used directly. They provide an easy way to +// generate a switch statement that can handle subsets of Dtypes supported. + +#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ + case enum_type: { \ + using CTYPE_ALIAS = ::mlx::core::DtypeToCppType::type; \ + __VA_ARGS__; \ + break; \ + } + +#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \ + switch (TYPE) { \ + __VA_ARGS__ \ + default: \ + throw std::invalid_argument(fmt::format( \ + "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \ + } + +#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__) + +// Switch case macros +// +// These macros provide an easy way to generate switch statements that apply a +// common lambda function to subsets of Dtypes supported by MLX. +// The lambda function can type specialize to the ctype associated with the +// Dtype being handled through an alias passed as the CTYPE_ALIAS argument. 
+// +// Arguments: +// - ADDITIONAL: Additional Dtype case to add +// - TYPE: The Dtype to handle through the switch statement +// - NAME: A name for this operation which will be used in error messages +// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype. +// - ...: A statement to be applied to each Dtype case +// +// An example usage is: +// +// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, { +// output.data[0] = input.data[0]; +// }); +// +// Note that these can be nested as well: +// +// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, { +// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, { +// output.data[0] = input.data[0]; +// }); +// }); +// +// These macros are adapted from Dispatch.h in the ATen library. The primary +// difference is that the CTYPE_ALIAS argument is exposed to users, which is +// used to alias the ctype associated with the Dtype that is being handled. + +#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \ + switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) } + +#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +} // namespace mlx::core diff --git a/mlx/event.h b/mlx/event.h index 7054b60c7..937028e2a 100644 --- a/mlx/event.h +++ b/mlx/event.h @@ -16,9 +16,6 @@ class Event { // Wait for the event to be signaled at its current value void wait(); - // Signal the event at its current value - void signal(); - // Wait in the given stream for the event to be signaled at its current value void wait(Stream stream); diff --git a/mlx/export.cpp b/mlx/export.cpp index 4eb3ff99a..effc7a0c1 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. 
#include "mlx/export.h" +#include #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" @@ -278,6 +279,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(LogicalAnd), SERIALIZE_PRIMITIVE(LogicalOr), SERIALIZE_PRIMITIVE(LogAddExp), + SERIALIZE_PRIMITIVE(LogSumExp), SERIALIZE_PRIMITIVE(Matmul), SERIALIZE_PRIMITIVE(Maximum), SERIALIZE_PRIMITIVE(Minimum), @@ -297,7 +299,13 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Reshape), SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"), SERIALIZE_PRIMITIVE(Round), - SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"), + SERIALIZE_PRIMITIVE( + Scan, + "CumSum", + "CumProd", + "CumMin", + "CumMax", + "CumLogaddexp"), SERIALIZE_PRIMITIVE(Scatter), SERIALIZE_PRIMITIVE(Select), SERIALIZE_PRIMITIVE(Sigmoid), @@ -474,7 +482,9 @@ bool FunctionTable::match( return false; } } - for (auto& [_, in] : kwargs) { + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + for (auto& [_, in] : sorted_kwargs) { if (!match_inputs(in, fun.inputs[i++])) { return false; } @@ -550,7 +560,9 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { // Flatten the inputs to the function for tracing std::vector kwarg_keys; auto inputs = args; - for (auto& [k, v] : kwargs) { + auto sorted_kwargs = + std::map(kwargs.begin(), kwargs.end()); + for (auto& [k, v] : sorted_kwargs) { kwarg_keys.push_back(k); inputs.push_back(v); } diff --git a/mlx/export.h b/mlx/export.h index da090510b..c6859c6d8 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -2,14 +2,14 @@ #pragma once -#include #include +#include #include "mlx/array.h" namespace mlx::core { using Args = std::vector; -using Kwargs = std::map; +using Kwargs = std::unordered_map; struct FunctionExporter; diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ac3cfe042..77210f713 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -9,6 +9,7 @@ #include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/transforms.h" +#include "mlx/transforms_impl.h" namespace mlx::core::fast { @@ -567,9 +568,9 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::variant& mask /* = {}*/, - const std::optional memory_efficient_threshold, - StreamOrDevice s) { + const std::string& mask_mode /* = "" */, + const std::vector& mask_arrs /* = {} */, + StreamOrDevice s /* = {}*/) { for (const auto& tensor : {queries, keys, values}) { if (tensor.ndim() != 4) { std::ostringstream msg; @@ -578,29 +579,49 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } } + // Check valid mask + if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode + << ". mask_mode must be 'causal', 'array' or ''."; + throw std::invalid_argument(msg.str()); + } bool do_causal = false; - bool has_mask = !std::holds_alternative(mask); - bool has_str_mask = has_mask && std::holds_alternative(mask); - bool has_arr_mask = has_mask && std::holds_alternative(mask); + bool has_mask = false; + bool has_arr_mask = false; bool has_bool_mask = false; - if (has_str_mask) { - if (std::get(mask) != "causal") { + if (mask_mode == "causal") { + has_mask = true; + do_causal = true; + + if (!mask_arrs.empty()) { std::ostringstream msg; - msg << "[scaled_dot_product_attention] invalid mask option '" - << std::get(mask) << "'. 
Must be 'causal', or an array."; + msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode " + << "'causal'. No array masks supported."; throw std::invalid_argument(msg.str()); - } else { - do_causal = true; } } - if (has_arr_mask && (std::get(mask)).ndim() > 4) { + if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) { + if (mask_arrs.size() != 1) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode " + << "'" << mask_mode << "'. Only 1 mask array is supported, got " + << mask_arrs.size() << " arrays."; + throw std::invalid_argument(msg.str()); + } + + has_mask = true; + has_arr_mask = true; + has_bool_mask = mask_arrs[0].dtype() == bool_; + } + + if (has_arr_mask && (mask_arrs[0]).ndim() > 4) { std::ostringstream msg; msg << "[scaled_dot_product_attention] the mask with shape " - << (std::get(mask)).shape() - << " expected to have at most rank 4"; + << mask_arrs[0].shape() << " expected to have at most rank 4."; throw std::invalid_argument(msg.str()); } @@ -654,13 +675,6 @@ array scaled_dot_product_attention( auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s); - /* Generic implementation for use cases that Metal implementation does not * support. */ - int threshold = 32; // TODO: Fix after dev - if (memory_efficient_threshold.has_value()) { - threshold = std::max(1, memory_efficient_threshold.value()); - } - auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s]( const std::vector& inputs) { auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); @@ -720,22 +734,20 @@ array scaled_dot_product_attention( const bool sdpa_vector_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); + (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || + query_head_dim == 256); const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); - const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); - const bool supports_sdpa_full = query_sequence_length >= threshold && - sdpa_full_supported_mask && sdpa_full_supported_head_dim && - stream.device == Device::gpu; + const bool supports_sdpa_full = sdpa_full_supported_mask && + sdpa_full_supported_head_dim && stream.device == Device::gpu; const bool supports_sdpa_vector = (query_sequence_length <= 8) && (query_sequence_length <= key_sequence_length) && - sdpa_vector_supported_mask && sdpa_vector_supported_head_dim && - stream.device == Device::gpu; + sdpa_vector_supported_head_dim && stream.device == Device::gpu; const bool implementation_supports_use_case = supports_sdpa_full || supports_sdpa_vector; @@ -743,20 +755,22 @@ array scaled_dot_product_attention( std::vector inputs = {q, k, v}; if (has_arr_mask) { // Check type - auto mask_arr = std::get(mask); + auto mask_arr = mask_arrs[0]; has_bool_mask = mask_arr.dtype() == bool_; if (promote_types(mask_arr.dtype(), final_type) != final_type) { std::ostringstream msg; msg << "[scaled_dot_product_attention] Mask type must promote to output type. 
" << final_type << "."; throw std::invalid_argument(msg.str()); + } else if (!has_bool_mask) { + mask_arr = astype(mask_arr, final_type, stream); } // Broadcast mask auto mask_shape = queries.shape(); mask_shape.back() = keys.shape(-2); inputs.push_back(broadcast_to(mask_arr, mask_shape, stream)); } - if (implementation_supports_use_case) { + if (!detail::in_grad_tracing() && implementation_supports_use_case) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; return array( std::move(out_shape), diff --git a/mlx/fast.h b/mlx/fast.h index b9db6d462..7aebe3863 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -48,8 +48,8 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::variant& mask = {}, - const std::optional memory_efficient_threshold = std::nullopt, + const std::string& mask_mode = "", + const std::vector& mask_arrs = {}, StreamOrDevice s = {}); std::tuple affine_quantize( diff --git a/mlx/memory.h b/mlx/memory.h new file mode 100644 index 000000000..8a264734c --- /dev/null +++ b/mlx/memory.h @@ -0,0 +1,78 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +/* Get the actively used memory in bytes. + * + * Note, this will not always match memory use reported by the system because + * it does not include cached memory buffers. + * */ +size_t get_active_memory(); + +/* Get the peak amount of used memory in bytes. + * + * The maximum memory used recorded from the beginning of the program + * execution or since the last call to reset_peak_memory. + * */ +size_t get_peak_memory(); + +/* Reset the peak memory to zero. + * */ +void reset_peak_memory(); + +/* Get the cache size in bytes. + * + * The cache includes memory not currently used that has not been returned + * to the system allocator. + * */ +size_t get_cache_memory(); + +/* Set the memory limit. + * The memory limit is a guideline for the maximum amount of memory to use + * during graph evaluation. If the memory limit is exceeded and there is no + * more RAM (including swap when available) allocations will result in an + * exception. + * + * When Metal is available the memory limit defaults to 1.5 times the maximum + * recommended working set size reported by the device. + * + * Returns the previous memory limit. + * */ +size_t set_memory_limit(size_t limit); + +/* Get the current memory limit. */ +size_t get_memory_limit(); + +/* Set the cache limit. + * If using more than the given limit, free memory will be reclaimed + * from the cache on the next allocation. To disable the cache, + * set the limit to 0. + * + * The cache limit defaults to the memory limit. + * + * Returns the previous cache limit. + * */ +size_t set_cache_limit(size_t limit); + +/* Clear the memory cache. */ +void clear_cache(); + +/* Set the wired size limit. + * + * Note, this function is only useful when using the Metal backend with + * macOS 15.0 or higher. + * + * The wired limit is the total size in bytes of memory that will be kept + * resident. The default value is ``0``. + * + * Setting a wired limit larger than system wired limit is an error. + * + * Returns the previous wired limit. 
+ * */ +size_t set_wired_limit(size_t limit); + +} // namespace mlx::core diff --git a/mlx/mlx.h b/mlx/mlx.h index 0fc657ca4..cef8d806d 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -14,6 +14,7 @@ #include "mlx/fft.h" #include "mlx/io.h" #include "mlx/linalg.h" +#include "mlx/memory.h" #include "mlx/ops.h" #include "mlx/random.h" #include "mlx/stream.h" diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4e147487d..54ac62fef 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -993,6 +993,9 @@ array concatenate( throw std::invalid_argument( "[concatenate] No arrays provided for concatenation"); } + if (arrays.size() == 1) { + return arrays[0]; + } auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] "); @@ -2356,6 +2359,29 @@ array logsumexp( const std::vector& axes, bool keepdims /* = false */, StreamOrDevice s /* = {}*/) { + if (a.size() == 0) { + throw std::invalid_argument("[logsumexp] Received empty array."); + } + if (a.ndim() == 0 && !axes.empty()) { + throw std::invalid_argument( + "[logsumexp] Received non-empty axes for array with 0 dimensions."); + } + bool is_complex = issubdtype(a.dtype(), complexfloating); + if (!is_complex && axes.size() == 1 && + (a.ndim() == axes[0] + 1 || axes[0] == -1)) { + auto dtype = at_least_float(a.dtype()); + auto out_shape = a.shape(); + out_shape.back() = 1; + auto out = array( + std::move(out_shape), + dtype, + std::make_shared(to_stream(s)), + {astype(a, dtype, s)}); + if (!keepdims) { + out = squeeze(out, -1, s); + } + return out; + } auto maxval = stop_gradient(max(a, axes, true, s), s); auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s); out = add(out, reshape(maxval, out.shape(), s), s); @@ -2823,6 +2849,19 @@ array matmul( } // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); + // Complex matmul in terms of real matmuls + if (out_type == complex64) { + auto a_real = real(a, s); + auto b_real = real(b, s); + auto a_imag = imag(a, s); + auto b_imag = imag(b, s); + auto c_real = + subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s); + auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s); + return add( + c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); + } + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[matmul] Only real floating point types are supported but " @@ -3331,8 +3370,14 @@ array softmax( if (a.size() == 0) { return a; } + if (a.ndim() == 0 && !axes.empty()) { + throw std::invalid_argument( + "[softmax] Received non-empty axes for array with 0 dimensions."); + } - if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) { + bool is_complex = issubdtype(a.dtype(), complexfloating); + if (!is_complex && axes.size() == 1 && + (a.ndim() == axes[0] + 1 || axes[0] == -1)) { auto dtype = at_least_float(a.dtype()); return array( a.shape(), @@ -3341,7 +3386,7 @@ array softmax( {astype(a, dtype, s)}); } else { auto in = a; - if (precise) { + if (precise && !is_complex) { in = astype(a, float32, s); } auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s); @@ -3459,6 +3504,28 @@ array cummin( {a}); } +array logcumsumexp( + const array& a, + int axis, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + int ndim = a.ndim(); + if (axis >= ndim || axis < -ndim) { + std::ostringstream msg; + msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + axis = (axis + 
a.ndim()) % a.ndim(); + return array( + a.shape(), + a.dtype(), + std::make_shared( + to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive), + {a}); +} + /** Convolution operations */ namespace { @@ -3961,6 +4028,7 @@ array gather_qmm( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( @@ -4000,13 +4068,19 @@ array gather_qmm( return array( std::move(out_shape), out_type, - std::make_shared(to_stream(s), group_size, bits, transpose), + std::make_shared( + to_stream(s), + group_size, + bits, + transpose, + sorted_indices && !rhs_indices_, + sorted_indices && !lhs_indices_), {astype(x, out_type, s), - w, + std::move(w), astype(scales, out_type, s), astype(biases, out_type, s), - lhs_indices, - rhs_indices}); + std::move(lhs_indices), + std::move(rhs_indices)}); } array tensordot( @@ -4147,6 +4221,14 @@ array addmm( // Type promotion auto out_type = result_type(a, b, c); + + if (out_type == complex64) { + return add( + multiply(matmul(a, b, s), array(alpha), s), + multiply(array(beta), c, s), + s); + } + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " @@ -4424,6 +4506,7 @@ array gather_mm( array b, std::optional lhs_indices_ /* = std::nullopt */, std::optional rhs_indices_ /* = std::nullopt */, + bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { // If no indices, fall back to full matmul if (!lhs_indices_ && !rhs_indices_) { @@ -4499,12 +4582,18 @@ array gather_mm( out_shape.push_back(M); out_shape.push_back(N); - // Caculate array + // Make the output array auto out = array( std::move(out_shape), out_type, - std::make_shared(to_stream(s)), - {a, b, lhs_indices, rhs_indices}); + std::make_shared( + to_stream(s), + sorted_indices && !rhs_indices_, + sorted_indices && !lhs_indices_), + {std::move(a), + std::move(b), + std::move(lhs_indices), + std::move(rhs_indices)}); // Remove the possibly inserted singleton dimensions std::vector axes; @@ -4826,8 +4915,10 @@ array operator^(const array& a, const array& b) { } array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - // Bit shift on bool always up-casts to uint8 - auto t = promote_types(result_type(a, b), uint8); + auto t = result_type(a, b); + if (t == bool_) { + t = uint8; + } return bitwise_impl( astype(a, t, s), astype(b, t, s), @@ -4840,8 +4931,10 @@ array operator<<(const array& a, const array& b) { } array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - // Bit shift on bool always up-casts to uint8 - auto t = promote_types(result_type(a, b), uint8); + auto t = result_type(a, b); + if (t == bool_) { + t = uint8; + } return bitwise_impl( astype(a, t, s), astype(b, t, s), diff --git a/mlx/ops.h b/mlx/ops.h index 02428b974..e79ea235d 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {}); /** Returns topk elements of the array along a given axis. */ array topk(const array& a, int k, int axis, StreamOrDevice s = {}); +/** Cumulative logsumexp of an array. */ +array logcumsumexp( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + /** The logsumexp of all elements of the array. 
*/ array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); inline array logsumexp(const array& a, StreamOrDevice s = {}) { @@ -1344,6 +1352,7 @@ array gather_qmm( bool transpose = true, int group_size = 64, int bits = 4, + bool sorted_indices = false, StreamOrDevice s = {}); /** Returns a contraction of a and b over multiple dimensions. */ @@ -1391,6 +1400,7 @@ array gather_mm( array b, std::optional lhs_indices = std::nullopt, std::optional rhs_indices = std::nullopt, + bool sorted_indices = false, StreamOrDevice s = {}); /** Extract a diagonal or construct a diagonal array */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b5e5ec82e..590af60f6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2509,6 +2509,49 @@ std::pair, std::vector> LogAddExp::vmap( return {{logaddexp(a, b, stream())}, {to_ax}}; } +std::pair, std::vector> LogSumExp::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto ax = axes[0]; + auto in = inputs[0]; + if (ax == (in.ndim() - 1)) { + in = swapaxes(in, -1, -2, stream()); + ax = in.ndim() - 2; + } + return {{logsumexp(in, -1, true, stream())}, {ax}}; +} + +std::vector LogSumExp::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + assert(primals.size() == 1); + assert(cotangents.size() == 1); + return {multiply( + cotangents[0], + softmax(primals[0], std::vector{-1}, true, stream()), + stream())}; +} + +std::vector LogSumExp::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(tangents.size() == 1); + return {multiply( + tangents[0], + softmax(primals[0], std::vector{-1}, true, stream()), + stream())}; +} + +std::vector LogSumExp::output_shapes(const std::vector& inputs) { + auto s = inputs[0].shape(); + s.back() = 1; + return {s}; +} + std::vector Matmul::vjp( const std::vector& primals, const std::vector& cotangents, @@ -3037,6 +3080,8 @@ std::vector GatherQMM::vjp( auto& lhs_indices = primals[4]; auto& rhs_indices = primals[5]; + bool sorted = left_sorted_ || right_sorted_; + for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { @@ -3055,6 +3100,7 @@ std::vector GatherQMM::vjp( !transpose_, group_size_, bits_, + sorted, stream()), -3, stream()), @@ -3435,6 +3481,45 @@ std::vector Scan::vjp( if (reduce_type_ == Scan::Sum) { return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; + } else if (reduce_type_ == Scan::LogAddExp) { + // Ref: + // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 + + auto x = primals[0]; + auto grad = cotangents[0]; + auto results = outputs[0]; + + auto zero = zeros({1}, grad.dtype(), stream()); + auto grad_min = array(finfo(grad.dtype()).min, grad.dtype()); + + // Split the incoming gradient into positive and negative part + // in order to take logs. This is required for stable results. 
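The gradient code that this comment introduces continues below; alongside it, a quick sanity check of the forward logcumsumexp op itself may help. This is a sketch against the public C++ API in mlx/mlx.h with arbitrary sample values: for moderate inputs the fused op should agree with the naive log(cumsum(exp(x))) composition, while staying finite where the naive form would overflow.

// Sketch: compare the new logcumsumexp op with its naive composition.
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto x = mx::array({0.5f, -1.0f, 2.0f, 3.5f});
  auto fused = mx::logcumsumexp(x, /* axis */ 0);
  // Naive composition: stable only while exp(x) does not overflow.
  auto naive = mx::log(mx::cumsum(mx::exp(x), /* axis */ 0));
  std::cout << "fused: " << fused << "\n"
            << "naive: " << naive << std::endl;
  return 0;
}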
+ auto log_abs_grad = log(abs(grad, stream()), stream()); + auto log_grad_positive = + where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream()); + auto log_grad_negative = + where(less(grad, zero, stream()), log_abs_grad, grad_min, stream()); + + auto output_pos = exp( + add(logcumsumexp( + subtract(log_grad_positive, results, stream()), + axis_, + !reverse_, + inclusive_, + stream()), + x, + stream())); + auto output_neg = exp( + add(logcumsumexp( + subtract(log_grad_negative, results, stream()), + axis_, + !reverse_, + inclusive_, + stream()), + x, + stream())); + + return {subtract(output_pos, output_neg, stream())}; } else if (reduce_type_ == Scan::Prod) { auto in = primals[0]; // Find the location of the first 0 and set it to 1: @@ -4813,6 +4898,8 @@ std::vector GatherMM::vjp( int N = cotan.shape(-1); int K = primals[0].shape(-1); + bool sorted = left_sorted_ || right_sorted_; + for (auto arg : argnums) { if (arg == 0) { // M X N * (K X N).T -> M X K @@ -4823,7 +4910,8 @@ std::vector GatherMM::vjp( base = reshape(base, {-1, M, K}, stream()); // g : (out_batch_shape) + (M, K) - auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream()); + auto g = + gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); g = expand_dims(g, -3, stream()); auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); @@ -4838,7 +4926,8 @@ std::vector GatherMM::vjp( base = reshape(base, {-1, K, N}, stream()); // g : (out_batch_shape) + (K, N) - auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream()); + auto g = + gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); g = expand_dims(g, -3, stream()); auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); @@ -4851,6 +4940,12 @@ std::vector GatherMM::vjp( return vjps; } +bool GatherMM::is_equivalent(const Primitive& other) const { + const GatherMM& g_other = static_cast(other); + return left_sorted_ == g_other.left_sorted_ && + right_sorted_ == g_other.right_sorted_; +} + bool BlockMaskedMM::is_equivalent(const Primitive& other) const { const BlockMaskedMM& a_other = static_cast(other); return (block_size_ == a_other.block_size_); diff --git a/mlx/primitives.h b/mlx/primitives.h index bb0ca8080..997931f30 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive { class GatherMM : public UnaryPrimitive { public: - explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {} + explicit GatherMM( + Stream stream, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive { const std::vector& outputs) override; DEFINE_PRINT(GatherMM) - DEFINE_DEFAULT_IS_EQUIVALENT() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(left_sorted_, right_sorted_); + } + + private: + bool left_sorted_; + bool right_sorted_; }; class BroadcastAxes : public UnaryPrimitive { @@ -1350,6 +1363,20 @@ class LogAddExp : public UnaryPrimitive { DEFINE_INPUT_OUTPUT_SHAPE() }; +class LogSumExp : public UnaryPrimitive { + public: + explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + 
DEFINE_GRADS() + DEFINE_PRINT(LogSumExp) + DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override; +}; + class Matmul : public UnaryPrimitive { public: explicit Matmul(Stream stream) : UnaryPrimitive(stream) {} @@ -1564,11 +1591,19 @@ class QuantizedMatmul : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive { public: - explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) + explicit GatherQMM( + Stream stream, + int group_size, + int bits, + bool transpose, + bool left_sorted = false, + bool right_sorted = false) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - transpose_(transpose) {} + transpose_(transpose), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1578,13 +1613,16 @@ class GatherQMM : public UnaryPrimitive { DEFINE_PRINT(GatherQMM) bool is_equivalent(const Primitive& other) const override; auto state() const { - return std::make_tuple(group_size_, bits_, transpose_); + return std::make_tuple( + group_size_, bits_, transpose_, left_sorted_, right_sorted_); } private: int group_size_; int bits_; bool transpose_; + bool left_sorted_; + bool right_sorted_; }; class RandomBits : public UnaryPrimitive { @@ -1714,7 +1752,7 @@ class Round : public UnaryPrimitive { class Scan : public UnaryPrimitive { public: - enum ReduceType { Max, Min, Sum, Prod }; + enum ReduceType { Max, Min, Sum, Prod, LogAddExp }; explicit Scan( Stream stream, @@ -1749,6 +1787,9 @@ class Scan : public UnaryPrimitive { case Max: os << "Max"; break; + case LogAddExp: + os << "Logaddexp"; + break; } } bool is_equivalent(const Primitive& other) const override; diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index e3612f351..7bd128c10 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -56,8 +56,16 @@ namespace scheduler { /** A singleton scheduler to manage devices, streams, and task execution. */ Scheduler& scheduler() { + // Leak the scheduler on Windows to avoid joining threads on exit, can be + // removed after Visual Studio fixes bug: + // https://developercommunity.visualstudio.com/t/1654756 +#ifdef _WIN32 + static Scheduler* scheduler = new Scheduler; + return *scheduler; +#else static Scheduler scheduler; return scheduler; +#endif } } // namespace scheduler diff --git a/mlx/scheduler.h b/mlx/scheduler.h index bf34b38c0..b2c6b842b 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -129,7 +129,7 @@ class Scheduler { int n_tasks_old = n_active_tasks(); if (n_tasks_old > 1) { completion_cv.wait(lk, [this, n_tasks_old] { - return this->n_active_tasks() != n_tasks_old; + return this->n_active_tasks() < n_tasks_old; }); } } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 105a0fa28..b305257f0 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -12,6 +12,7 @@ #include "mlx/backend/cpu/eval.h" #include "mlx/backend/metal/metal_impl.h" #include "mlx/fence.h" +#include "mlx/memory.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -41,7 +42,8 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. 
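The InTracing changes below extend this bookkeeping with a grad counter, so code such as fast.cpp can ask whether a vjp/jvp trace is currently active. A self-contained sketch of the same RAII pattern follows; the names TraceGuard and grad_trace_depth are illustrative, not MLX API.

// Sketch: an RAII guard bumps a counter for grad traces on construction and
// releases it on destruction, so nested traces are handled naturally.
#include <cassert>
#include <utility>
#include <vector>

static int grad_trace_depth = 0;
static std::vector<std::pair<bool, bool>> trace_stack;  // {dynamic, grad}

struct TraceGuard {
  explicit TraceGuard(bool dynamic = false, bool grad = false) {
    grad_trace_depth += grad;
    trace_stack.push_back({dynamic, grad});
  }
  ~TraceGuard() {
    grad_trace_depth -= trace_stack.back().second;
    trace_stack.pop_back();
  }
};

bool in_grad_tracing() {
  return grad_trace_depth > 0;
}

int main() {
  assert(!in_grad_tracing());
  {
    TraceGuard vjp_scope(false, true);  // e.g. entering vjp()
    assert(in_grad_tracing());
    {
      TraceGuard compile_scope(true, false);  // nested non-grad trace
      assert(in_grad_tracing());              // still inside a grad trace
    }
  }
  assert(!in_grad_tracing());
  return 0;
}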
-std::vector detail::InTracing::trace_stack{}; +std::vector> detail::InTracing::trace_stack{}; +int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { @@ -219,7 +221,7 @@ array eval_impl(std::vector outputs, bool async) { } if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS || - (metal::get_active_memory() > metal::get_memory_limit() && + (get_active_memory() > get_memory_limit() && scheduler::n_active_tasks() > 0)) { // Commit any open streams for (auto& [_, e] : events) { @@ -228,8 +230,7 @@ array eval_impl(std::vector outputs, bool async) { } } scheduler::wait_for_one(); - // TODO memory api should be moved out of metal - while (metal::get_active_memory() > metal::get_memory_limit() && + while (get_active_memory() > get_memory_limit() && scheduler::n_active_tasks() > 0) { scheduler::wait_for_one(); } @@ -307,7 +308,7 @@ std::pair, std::vector> vjp( const std::vector& cotans, const std::vector& argnums) { // Set the global tracing flag. - detail::InTracing in_tracing; + detail::InTracing in_tracing{false, true}; // Make tracers from given primals std::vector primals_; @@ -505,7 +506,7 @@ std::pair, std::vector> jvp( const std::vector& primals, const std::vector& tangents) { // Set the global tracing flag. - detail::InTracing in_tracing; + detail::InTracing in_tracing{false, true}; if (primals.size() != tangents.size()) { throw std::invalid_argument( diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 3aa84bde4..7f62c406b 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -20,10 +20,12 @@ std::vector vmap_replace( // of the codebase that we are during tracing so evals should not throw away // the graph. struct InTracing { - explicit InTracing(bool dynamic = false) { - trace_stack.push_back(dynamic); + explicit InTracing(bool dynamic = false, bool grad = false) { + grad_counter += grad; + trace_stack.push_back({dynamic, grad}); } ~InTracing() { + grad_counter -= trace_stack.back().second; trace_stack.pop_back(); } @@ -32,11 +34,16 @@ struct InTracing { } static bool in_dynamic_tracing() { // compile is always and only the outer-most transform - return in_tracing() && trace_stack.front(); + return in_tracing() && trace_stack.front().first; + } + + static bool in_grad_tracing() { + return grad_counter > 0; } private: - static std::vector trace_stack; + static int grad_counter; + static std::vector> trace_stack; }; struct RetainGraph { @@ -67,6 +74,11 @@ inline bool in_dynamic_tracing() { return detail::InTracing::in_dynamic_tracing(); } +/** Return true if we are in a gradient trace (vjp, jvp, etc). 
*/ +inline bool in_grad_tracing() { + return detail::InTracing::in_grad_tracing(); +} + inline bool retain_graph() { return detail::RetainGraph::retain_graph(); } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 9168b34c8..188584174 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -5,6 +5,7 @@ #include #include +#include "mlx/dtype_utils.h" #include "mlx/types/limits.h" #include "mlx/utils.h" @@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) { } // namespace std::ostream& operator<<(std::ostream& os, const Dtype& dtype) { - switch (dtype) { - case bool_: - return os << "bool"; - case uint8: - return os << "uint8"; - case uint16: - return os << "uint16"; - case uint32: - return os << "uint32"; - case uint64: - return os << "uint64"; - case int8: - return os << "int8"; - case int16: - return os << "int16"; - case int32: - return os << "int32"; - case int64: - return os << "int64"; - case float16: - return os << "float16"; - case float32: - return os << "float32"; - case float64: - return os << "float64"; - case bfloat16: - return os << "bfloat16"; - case complex64: - return os << "complex64"; - } - return os; + return os << dtype_to_string(dtype); } std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { @@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { std::ostream& operator<<(std::ostream& os, array a) { a.eval(); - switch (a.dtype()) { - case bool_: - print_array(os, a); - break; - case uint8: - print_array(os, a); - break; - case uint16: - print_array(os, a); - break; - case uint32: - print_array(os, a); - break; - case uint64: - print_array(os, a); - break; - case int8: - print_array(os, a); - break; - case int16: - print_array(os, a); - break; - case int32: - print_array(os, a); - break; - case int64: - print_array(os, a); - break; - case float16: - print_array(os, a); - break; - case bfloat16: - print_array(os, a); - break; - case float32: - print_array(os, a); - break; - case float64: - print_array(os, a); - break; - case complex64: - print_array(os, a); - break; - } + MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array(os, a)); return os; } @@ -380,4 +308,15 @@ finfo::finfo(Dtype dtype) : dtype(dtype) { } } +template +void set_iinfo_limits(int64_t& min, uint64_t& max) { + min = std::numeric_limits::min(); + max = std::numeric_limits::max(); +} + +iinfo::iinfo(Dtype dtype) : dtype(dtype) { + MLX_SWITCH_INT_TYPES_CHECKED( + dtype, "[iinfo]", CTYPE, set_iinfo_limits(min, max)); +} + } // namespace mlx::core diff --git a/mlx/utils.h b/mlx/utils.h index 0b5ce54a1..19241e4c6 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -67,6 +67,14 @@ struct finfo { double max; }; +/** Holds information about integral types. */ +struct iinfo { + explicit iinfo(Dtype dtype); + Dtype dtype; + int64_t min; + uint64_t max; +}; + /** The type from promoting the arrays' types with one another. 
*/ inline Dtype result_type(const array& a, const array& b) { return promote_types(a.dtype(), b.dtype()); diff --git a/mlx/version.h b/mlx/version.h index 275c74c73..fe47d96cc 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -3,7 +3,7 @@ #pragma once #define MLX_VERSION_MAJOR 0 -#define MLX_VERSION_MINOR 24 +#define MLX_VERSION_MINOR 25 #define MLX_VERSION_PATCH 0 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 5d6bc4383..9c946005b 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -172,7 +172,7 @@ def parse_hostfile(parser, hostfile): for i, h in enumerate(json.load(f)): hosts.append(Host(i, h["ssh"], h.get("ips", []))) return hosts - except e: + except Exception as e: parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})") @@ -761,6 +761,8 @@ def main(): "--cwd", help="Set the working directory on each node to the provided one" ) args, rest = parser.parse_known_args() + if rest[0] == "--": + rest.pop(0) if args.print_python: print(sys.executable) diff --git a/python/mlx/extension.py b/python/mlx/extension.py index ecf3c52e6..8c0d60655 100644 --- a/python/mlx/extension.py +++ b/python/mlx/extension.py @@ -30,10 +30,6 @@ class CMakeBuild(build_ext): debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug cfg = "Debug" if debug else "Release" - # CMake lets you override the generator - we need to check this. - # Can be set with Conda-Build, for example. - cmake_generator = os.environ.get("CMAKE_GENERATOR", "") - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code # from Python. diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index c1d89fed9..26f77917f 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import ( ConvTranspose2d, ConvTranspose3d, ) +from mlx.nn.layers.distributed import ( + AllToShardedLinear, + QuantizedAllToShardedLinear, + QuantizedShardedToAllLinear, + ShardedToAllLinear, +) from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Bilinear, Identity, Linear diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py new file mode 100644 index 000000000..92acde8f6 --- /dev/null +++ b/python/mlx/nn/layers/distributed.py @@ -0,0 +1,599 @@ +# Copyright © 2024 Apple Inc. 
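For orientation, a minimal usage sketch of the sharded linear layers this new module provides (an illustrative sketch, not an excerpt of the patch; it assumes a distributed group is available, for example when launched via ``mlx.launch``, and the layer sizes are arbitrary):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

group = mx.distributed.init()

# Shard the hidden dimension of a small MLP across the group: the first
# projection scatters its output across ranks, the second reduces the
# partial results back together with an all_sum.
up = shard_linear(nn.Linear(1024, 4096), "all-to-sharded", group=group)
down = shard_linear(nn.Linear(4096, 1024), "sharded-to-all", group=group)

x = mx.random.normal((2, 1024))
y = down(up(x))  # every rank ends up with the same (2, 1024) output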
+ +import math +from functools import lru_cache +from typing import Callable, Optional, Union + +import mlx.core as mx +from mlx.nn.layers.base import Module +from mlx.nn.layers.linear import Linear +from mlx.nn.layers.quantized import QuantizedLinear +from mlx.utils import tree_map_with_path + + +@lru_cache +def sum_gradients(group): + if group.size() == 1: + return lambda x: x + + @mx.custom_function + def f(x): + return x + + @f.vjp + def f(x, dx, _): + return mx.distributed.all_sum(dx, group=group) + + return f + + +def _split(weight, segments, axis): + """Equivalent to mx.split but allows for fractional segments.""" + if isinstance(segments, int) or isinstance(segments[0], int): + return mx.split(weight, segments, axis=axis) + + N = weight.shape[axis] + indices = [int(s * N) for s in segments] + return mx.split(weight, indices, axis=axis) + + +def _shard( + parameters: dict, + sharding_predicate: Callable, + group: Optional[mx.distributed.Group] = None, +): + """Returns a new parameter tree with the weights sharded according to the + sharding_predicate. + + The sharding predicate should return the sharding axis and optionally also + the segments that comprise the weight. + """ + group = group or mx.distributed.init() + N = group.size() + r = group.rank() + + def _shard_fn(path, weight): + if not isinstance(weight, mx.array): + return weight + + s = sharding_predicate(path, weight) + if s is None: + return weight + + axis = None + segments = 1 + if isinstance(s, int): + axis = s + elif isinstance(s, tuple): + axis, segments = s + else: + raise ValueError( + "The sharding function should return int or tuple[int, list]" + ) + + return mx.contiguous( + mx.concatenate( + [_split(part, N, axis)[r] for part in _split(weight, segments, axis)], + axis=axis, + ) + ) + + return tree_map_with_path(_shard_fn, parameters) + + +def _all_to_sharded(segments): + """Simple predicate to shard fully connected layers such that a common + representation becomes a sharded representation.""" + + def _shard_fn(path, weight): + return max(weight.ndim - 2, 0), segments + + return _shard_fn + + +def _sharded_to_all(segments): + """Simple predicate to shard fully connected layers such that a sharded + representation becomes a common representation.""" + + def _shard_fn(path, weight): + if path.endswith("bias"): + return None + return -1, segments + + return _shard_fn + + +def _check_sharding(sharding): + if sharding not in ("all-to-sharded", "sharded-to-all"): + raise ValueError( + ( + f"Sharding type {sharding=} not supported, " + "choose one of 'all-to-sharded' or 'sharded-to-all'" + ) + ) + + +def shard_inplace( + module: Module, + sharding: Union[str, Callable], + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, +): + """Shard a module in-place by updating its parameter dictionary with the + sharded parameter dictionary. + + The ``sharding`` argument can be any callable that given the path and the + weight returns the sharding axis and optionally also the segments that + comprise the unsharded weight. For instance if the weight is a fused QKV + matrix the segments should be 3. + + .. note:: + The module doesn't change so in order for distributed communication to + happen the module needs to natively support it and for it to be enabled. + + Args: + module (mlx.nn.Module): The parameters of this module will be sharded + in-place. + sharding (str or callable): One of "all-to-sharded" and + "sharded-to-all" or a callable that returns the sharding axis and + segments. 
+ segments (int or list): The segments to use if ``sharding`` is a + string. Default: ``1``. + group (mlx.core.distributed.Group): The distributed group to shard + across. If not set, the global group will be used. Default: ``None``. + """ + if isinstance(sharding, str): + _check_sharding(sharding) + sharding = ( + _all_to_sharded(segments) + if sharding == "all-to-sharded" + else _sharded_to_all(segments) + ) + module.update(_shard(module.parameters(), sharding, group)) + + +def shard_linear( + module: Module, + sharding: str, + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, +): + """Create a new linear layer that has its parameters sharded and also + performs distributed communication either in the forward or backward + pass. + + .. note:: + Contrary to ``shard_inplace``, the original layer is not changed but a + new layer is returned. + + Args: + module (mlx.nn.Module): The linear layer to be sharded. + sharding (str): One of "all-to-sharded" and + "sharded-to-all" that defines the type of sharding to perform. + segments (int or list): The segments to use. Default: ``1``. + group (mlx.core.distributed.Group): The distributed group to shard + across. If not set, the global group will be used. Default: ``None``. + """ + _check_sharding(sharding) + fns = { + ("all-to-sharded", True): AllToShardedLinear.from_linear, + ("all-to-sharded", False): QuantizedAllToShardedLinear.from_quantized_linear, + ("sharded-to-all", True): ShardedToAllLinear.from_linear, + ("sharded-to-all", False): QuantizedShardedToAllLinear.from_quantized_linear, + } + return fns[sharding, isinstance(module, Linear)]( + module, segments=segments, group=group + ) + + +class AllToShardedLinear(Module): + """Each member of the group applies part of the affine transformation such + that the result is sharded across the group. + + The gradients are automatically aggregated from each member of the group. + + Args: + input_dims (int): The dimensionality of the input features + output_dims (int): The dimensionality of the output features + bias (bool, optional): If set to ``False`` the the layer will not use a + bias. Default is ``True``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Initialize the parameters + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (output_dims % N) != 0: + raise ValueError( + f"Cannot shard the output of size {output_dims} across {N} devices." 
+ ) + + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N, input_dims), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N,), + ) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + N = self.group.size() + out_dims *= N + return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" + + def __call__(self, x: mx.array) -> mx.array: + # Aggregate the gradients coming from each shard + x = sum_gradients(self.group)(x) + + # Compute the affine projection + if "bias" in self: + x = mx.addmm(self["bias"], x, self["weight"].T) + else: + x = x @ self["weight"].T + return x + + @classmethod + def from_linear( + cls, + linear_layer: Module, + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = linear_layer.weight.shape + + sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) + sl.update(_shard(linear_layer.parameters(), _all_to_sharded(segments), group)) + + return sl + + +class ShardedToAllLinear(Module): + """Each member of the group applies part of the affine transformation and + then aggregates the results. + + All nodes will have the same exact result after this layer. + + :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to + convert linear layers to sharded :obj:`ShardedToAllLinear` layers. + + Args: + input_dims (int): The dimensionality of the input features + output_dims (int): The dimensionality of the output features + bias (bool, optional): If set to ``False`` the the layer will not use a + bias. Default is ``True``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Initialize the parameters + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (input_dims % N) != 0: + raise ValueError( + f"The input of size {input_dims} cannot be sharded across {N} devices." + ) + + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims // N), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims,), + ) + + def _extra_repr(self) -> str: + N = self.group.size() + out_dims, in_dims = self.weight.shape + in_dims *= N + return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" + + def __call__(self, x: mx.array) -> mx.array: + x = x @ self["weight"].T + + x = mx.distributed.all_sum(x, group=self.group) + + if "bias" in self: + x = x + self["bias"] + + return x + + @classmethod + def from_linear( + cls, + linear_layer: Module, + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = linear_layer.weight.shape + + sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) + sl.update(_shard(linear_layer.parameters(), _sharded_to_all(segments), group)) + + return sl + + +class QuantizedAllToShardedLinear(Module): + """Each member of the group applies part of the affine transformation with + a quantized matrix such that the result is sharded across the group. 
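Where a single weight fuses several logical matrices, the ``segments`` argument of ``shard_inplace`` splits it before sharding so each logical block stays intact per rank. A brief sketch with a hypothetical fused QKV projection (the layer name and sizes are made up; note that ``shard_inplace`` only reshards the parameters, so the surrounding module must already perform whatever collective communication it needs):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace

group = mx.distributed.init()
qkv_proj = nn.Linear(512, 3 * 512, bias=False)  # hypothetical fused Q, K, V projection

# Treat the output dimension as three equal segments (Q, K and V) and shard
# each segment row-wise across the group, preserving the Q/K/V layout locally.
shard_inplace(qkv_proj, "all-to-sharded", segments=3, group=group)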
+ + It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. + Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and + will not be included in any gradient computation. + + Args: + input_dims (int): The dimensionality of the input features. + output_dims (int): The dimensionality of the output features. + bias (bool, optional): If set to ``False`` then the layer will not use + a bias. Default: ``True``. + group_size (int, optional): The group size to use for the quantized + weight. See :func:`~mlx.core.quantize`. Default: ``64``. + bits (int, optional): The bit width to use for the quantized weight. + See :func:`~mlx.core.quantize`. Default: ``4``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Quantization config + self.group_size = group_size + self.bits = bits + + # Initialize the quantized weight + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (output_dims % N) != 0: + raise ValueError( + f"Cannot shard the output of size {output_dims} across {N} devices." + ) + + weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N, input_dims), + ) + self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + + # And bias if needed + if bias: + self.bias = mx.zeros((output_dims // N,)) + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + in_dims *= 32 // self.bits + out_dims *= self.group.size() + return ( + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " + f"group_size={self.group_size}, bits={self.bits}" + ) + + def __call__(self, x: mx.array) -> mx.array: + # Aggregate the gradients coming from each shard + x = sum_gradients(self.group)(x) + + x = mx.quantized_matmul( + x, + self["weight"], + scales=self["scales"], + biases=self["biases"], + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + if "bias" in self: + x = x + self["bias"] + return x + + @classmethod + def from_quantized_linear( + cls, + quantized_linear_layer: Module, + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = quantized_linear_layer.weight.shape + input_dims *= 32 // quantized_linear_layer.bits + + sl = cls( + input_dims, + output_dims, + hasattr(quantized_linear_layer, "bias"), + group_size=quantized_linear_layer.group_size, + bits=quantized_linear_layer.bits, + group=group, + ) + sl.update( + _shard( + quantized_linear_layer.parameters(), + _all_to_sharded(segments), + group, + ) + ) + + return sl + + +class QuantizedShardedToAllLinear(Module): + """Each member of the group applies part of the affine transformation using + the quantized matrix and then aggregates the results. + + All nodes will have the same exact result after this layer. + + It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. 
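The quantized variants are normally obtained by converting an existing quantized layer rather than constructed directly; a short sketch (illustrative sizes, mirroring what the distributed tests below do):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

group = mx.distributed.init()
qlin = nn.Linear(1024, 1024).to_quantized()               # QuantizedLinear
slin = shard_linear(qlin, "sharded-to-all", group=group)  # QuantizedShardedToAllLinear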
+ Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and + will not be included in any gradient computation. + + Args: + input_dims (int): The dimensionality of the input features. + output_dims (int): The dimensionality of the output features. + bias (bool, optional): If set to ``False`` then the layer will not use + a bias. Default: ``True``. + group_size (int, optional): The group size to use for the quantized + weight. See :func:`~mlx.core.quantize`. Default: ``64``. + bits (int, optional): The bit width to use for the quantized weight. + See :func:`~mlx.core.quantize`. Default: ``4``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Quantization config + self.group_size = group_size + self.bits = bits + + # Initialize the quantized weight + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (input_dims % N) != 0: + raise ValueError( + f"The input of size {input_dims} cannot be sharded across {N} devices." + ) + + weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims // N), + ) + self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + + # And bias if needed + if bias: + self.bias = mx.zeros((output_dims,)) + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + in_dims *= (32 // self.bits) * self.group.size() + return ( + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " + f"group_size={self.group_size}, bits={self.bits}" + ) + + def __call__(self, x: mx.array) -> mx.array: + x = mx.quantized_matmul( + x, + self["weight"], + scales=self["scales"], + biases=self["biases"], + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + x = mx.distributed.all_sum(x, group=self.group) + if "bias" in self: + x = x + self["bias"] + return x + + @classmethod + def from_quantized_linear( + cls, + quantized_linear_layer: Module, + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = quantized_linear_layer.weight.shape + input_dims *= 32 // quantized_linear_layer.bits + + sl = cls( + input_dims, + output_dims, + hasattr(quantized_linear_layer, "bias"), + group_size=quantized_linear_layer.group_size, + bits=quantized_linear_layer.bits, + group=group, + ) + sl.update( + _shard( + quantized_linear_layer.parameters(), + _sharded_to_all(segments), + group, + ) + ) + + return sl diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index 58232363a..aceb1f98a 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -352,7 +352,7 @@ def smooth_l1_loss( .. 
math:: l = \begin{cases} - 0.5 (x - y)^2, & \text{if } (x - y) < \beta \\ + 0.5 (x - y)^2 / \beta, & \text{if } |x - y| < \beta \\ |x - y| - 0.5 \beta, & \text{otherwise} \end{cases} diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py index d36f34891..7931c74fa 100644 --- a/python/mlx/optimizers/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -1,6 +1,5 @@ # Copyright © 2023-2024 Apple Inc. -import math from typing import Callable, List, Optional, Tuple, Union import mlx.core as mx diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index caaa478a3..7ea302cf9 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -17,6 +17,7 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp diff --git a/python/src/array.cpp b/python/src/array.cpp index 375f9a3ec..467bd0fa5 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -206,6 +206,30 @@ void init_array(nb::module_& m) { return os.str(); }); + nb::class_( + m, + "iinfo", + R"pbdoc( + Get information on integer types. + )pbdoc") + .def(nb::init()) + .def_ro( + "min", + &mx::iinfo::min, + R"pbdoc(The smallest representable number.)pbdoc") + .def_ro( + "max", + &mx::iinfo::max, + R"pbdoc(The largest representable number.)pbdoc") + .def_ro("dtype", &mx::iinfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") + .def("__repr__", [](const mx::iinfo& i) { + std::ostringstream os; + os << "iinfo(" + << "min=" << i.min << ", max=" << i.max << ", dtype=" << i.dtype + << ")"; + return os.str(); + }); + nb::class_( m, "ArrayAt", @@ -1178,6 +1202,28 @@ void init_array(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), "See :func:`max`.") + .def( + "logcumsumexp", + [](const mx::array& a, + std::optional axis, + bool reverse, + bool inclusive, + mx::StreamOrDevice s) { + if (axis) { + return mx::logcumsumexp(a, *axis, reverse, inclusive, s); + } else { + // TODO: Implement that in the C++ API as well. See concatenate + // above. + return mx::logcumsumexp( + mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "axis"_a = nb::none(), + nb::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = nb::none(), + "See :func:`logcumsumexp`.") .def( "logsumexp", [](const mx::array& a, diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index ff24b8a95..c9acc8583 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -117,7 +117,64 @@ void init_distributed(nb::module_& parent_module) { Returns: array: The sum of all ``x`` arrays. )pbdoc"); + m.def( + "all_max", + [](const ScalarOrArray& x, + std::optional group, + mx::StreamOrDevice s) { + return mx::distributed::all_max(to_array(x), group, s); + }, + "x"_a, + nb::kw_only(), + "group"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + All reduce max. + Find the maximum of the ``x`` arrays from all processes in the group. + + Args: + x (array): Input array. + group (Group): The group of processes that will participate in the + reduction. If set to ``None`` the global group is used. Default: + ``None``. + stream (Stream, optional): Stream or device. 
Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The maximum of all ``x`` arrays. + )pbdoc"); + m.def( + "all_min", + [](const ScalarOrArray& x, + std::optional group, + mx::StreamOrDevice s) { + return mx::distributed::all_min(to_array(x), group, s); + }, + "x"_a, + nb::kw_only(), + "group"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + All reduce min. + + Find the minimum of the ``x`` arrays from all processes in the group. + + Args: + x (array): Input array. + group (Group): The group of processes that will participate in the + reduction. If set to ``None`` the global group is used. Default: + ``None``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The minimum of all ``x`` arrays. + )pbdoc"); m.def( "all_gather", [](const ScalarOrArray& x, diff --git a/python/src/export.cpp b/python/src/export.cpp index feefeb12c..30062ae37 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -1,8 +1,8 @@ // Copyright © 2024 Apple Inc. #include -#include #include #include +#include #include #include @@ -16,8 +16,7 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -std::pair, std::map> -validate_and_extract_inputs( +std::pair validate_and_extract_inputs( const nb::args& args, const nb::kwargs& kwargs, const std::string& prefix) { @@ -30,8 +29,8 @@ validate_and_extract_inputs( "and/or dictionary of arrays."); } }; - std::vector args_; - std::map kwargs_; + mx::Args args_; + mx::Kwargs kwargs_; if (args.size() == 0) { // No args so kwargs must be keyword arrays maybe_throw(nb::try_cast(kwargs, kwargs_)); @@ -81,9 +80,7 @@ class PyFunctionExporter { void close() { exporter_.close(); } - void operator()( - const std::vector& args, - const std::map& kwargs) { + void operator()(const mx::Args& args, const mx::Kwargs& kwargs) { exporter_(args, kwargs); } @@ -98,9 +95,12 @@ int py_function_exporter_tp_traverse( PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } auto* p = nb::inst_ptr(self); Py_VISIT(p->dep_.ptr()); - Py_VISIT(Py_TYPE(self)); return 0; } @@ -109,23 +109,22 @@ PyType_Slot py_function_exporter_slots[] = { {0, 0}}; auto wrap_export_function(nb::callable fun) { - return [fun = std::move(fun)]( - const std::vector& args_, - const std::map& kwargs_) { - auto kwargs = nb::dict(); - kwargs.update(nb::cast(kwargs_)); - auto args = nb::tuple(nb::cast(args_)); - auto outputs = fun(*args, **kwargs); - std::vector outputs_; - if (nb::isinstance(outputs)) { - outputs_.push_back(nb::cast(outputs)); - } else if (!nb::try_cast(outputs, outputs_)) { - throw std::invalid_argument( - "[export_function] Outputs can be either a single array " - "a tuple or list of arrays."); - } - return outputs_; - }; + return + [fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) { + auto kwargs = nb::dict(); + kwargs.update(nb::cast(kwargs_)); + auto args = nb::tuple(nb::cast(args_)); + auto outputs = fun(*args, **kwargs); + std::vector outputs_; + if (nb::isinstance(outputs)) { + outputs_.push_back(nb::cast(outputs)); + } else if (!nb::try_cast(outputs, outputs_)) { + throw std::invalid_argument( + "[export_function] Outputs can be either a single array " + "a tuple or list of arrays."); + } + 
return outputs_; + }; } void init_export(nb::module_& m) { diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 95b7dcc9a..c94f99e1a 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -124,14 +124,45 @@ void init_fast(nb::module_& parent_module) { m.def( "scaled_dot_product_attention", - &mx::fast::scaled_dot_product_attention, + [](const mx::array& queries, + const mx::array& keys, + const mx::array& values, + const float scale, + const std::variant& mask, + mx::StreamOrDevice s) { + bool has_mask = !std::holds_alternative(mask); + bool has_str_mask = + has_mask && std::holds_alternative(mask); + bool has_arr_mask = has_mask && std::holds_alternative(mask); + + if (has_mask) { + if (has_str_mask) { + auto mask_str = std::get(mask); + if (mask_str != "causal") { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] invalid mask option '" + << mask_str << "'. Must be 'causal', or an array."; + throw std::invalid_argument(msg.str()); + } + return mx::fast::scaled_dot_product_attention( + queries, keys, values, scale, mask_str, {}, s); + } else { + auto mask_arr = std::get(mask); + return mx::fast::scaled_dot_product_attention( + queries, keys, values, scale, "", {mask_arr}, s); + } + + } else { + return mx::fast::scaled_dot_product_attention( + queries, keys, values, scale, "", {}, s); + } + }, "q"_a, "k"_a, "v"_a, nb::kw_only(), "scale"_a, "mask"_a = nb::none(), - "memory_efficient_threshold"_a = nb::none(), "stream"_a = nb::none(), nb::sig( "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"), @@ -164,10 +195,10 @@ void init_fast(nb::module_& parent_module) { k (array): Keys with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (Union[None, str, array], optional): A causal, boolean or additive - mask to apply to the query-key scores. The mask can have at most 4 - dimensions and must be broadcast-compatible with the shape - ``[B, N, T_q, T_kv]``. If an additive mask is given its type must + mask (Union[None, str, array], optional): A causal, boolean or additive + mask to apply to the query-key scores. The mask can have at most 4 + dimensions and must be broadcast-compatible with the shape + ``[B, N, T_q, T_kv]``. If an additive mask is given its type must promote to the promoted type of ``q``, ``k``, and ``v``. Returns: array: The output array. diff --git a/python/src/memory.cpp b/python/src/memory.cpp new file mode 100644 index 000000000..5ce9a765b --- /dev/null +++ b/python/src/memory.cpp @@ -0,0 +1,125 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/memory.h" +#include + +namespace mx = mlx::core; +namespace nb = nanobind; +using namespace nb::literals; + +void init_memory(nb::module_& m) { + m.def( + "get_active_memory", + &mx::get_active_memory, + R"pbdoc( + Get the actively used memory in bytes. + + Note, this will not always match memory use reported by the system because + it does not include cached memory buffers. + )pbdoc"); + m.def( + "get_peak_memory", + &mx::get_peak_memory, + R"pbdoc( + Get the peak amount of used memory in bytes. + + The maximum memory used recorded from the beginning of the program + execution or since the last call to :func:`reset_peak_memory`. + )pbdoc"); + m.def( + "reset_peak_memory", + &mx::reset_peak_memory, + R"pbdoc( + Reset the peak memory to zero. 
+ )pbdoc"); + m.def( + "get_cache_memory", + &mx::get_cache_memory, + R"pbdoc( + Get the cache size in bytes. + + The cache includes memory not currently used that has not been returned + to the system allocator. + )pbdoc"); + m.def( + "set_memory_limit", + &mx::set_memory_limit, + "limit"_a, + R"pbdoc( + Set the memory limit. + + The memory limit is a guideline for the maximum amount of memory to use + during graph evaluation. If the memory limit is exceeded and there is no + more RAM (including swap when available) allocations will result in an + exception. + + When metal is available the memory limit defaults to 1.5 times the + maximum recommended working set size reported by the device. + + Args: + limit (int): Memory limit in bytes. + + Returns: + int: The previous memory limit in bytes. + )pbdoc"); + m.def( + "set_cache_limit", + &mx::set_cache_limit, + "limit"_a, + R"pbdoc( + Set the free cache limit. + + If using more than the given limit, free memory will be reclaimed + from the cache on the next allocation. To disable the cache, set + the limit to ``0``. + + The cache limit defaults to the memory limit. See + :func:`set_memory_limit` for more details. + + Args: + limit (int): The cache limit in bytes. + + Returns: + int: The previous cache limit in bytes. + )pbdoc"); + m.def( + "set_wired_limit", + &mx::set_wired_limit, + "limit"_a, + R"pbdoc( + Set the wired size limit. + + .. note:: + * This function is only useful on macOS 15.0 or higher. + * The wired limit should remain strictly less than the total + memory size. + + The wired limit is the total size in bytes of memory that will be kept + resident. The default value is ``0``. + + Setting a wired limit larger than system wired limit is an error. You can + increase the system wired limit with: + + .. code-block:: + + sudo sysctl iogpu.wired_limit_mb= + + Use :func:`device_info` to query the system wired limit + (``"max_recommended_working_set_size"``) and the total memory size + (``"memory_size"``). + + Args: + limit (int): The wired limit in bytes. + + Returns: + int: The previous wired limit in bytes. + )pbdoc"); + m.def( + "clear_cache", + &mx::clear_cache, + R"pbdoc( + Clear the memory cache. + + After calling this, :func:`get_cache_memory` should return ``0``. + )pbdoc"); +} diff --git a/python/src/metal.cpp b/python/src/metal.cpp index fef856dd9..a13dd2a03 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -1,17 +1,27 @@ // Copyright © 2023-2024 Apple Inc. +#include -#include "mlx/backend/metal/metal.h" #include #include #include #include #include #include +#include "mlx/backend/metal/metal.h" +#include "mlx/memory.h" namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; +bool DEPRECATE(const std::string& old_fn, const std::string new_fn) { + std::cerr << old_fn << " is deprecated and will be removed in a future " + << "version. Use " << new_fn << " instead." << std::endl; + return true; +} + +#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn) + void init_metal(nb::module_& m) { nb::module_ metal = m.def_submodule("metal", "mlx.metal"); metal.def( @@ -20,121 +30,47 @@ void init_metal(nb::module_& m) { R"pbdoc( Check if the Metal back-end is available. )pbdoc"); - metal.def( - "get_active_memory", - &mx::metal::get_active_memory, - R"pbdoc( - Get the actively used memory in bytes. - - Note, this will not always match memory use reported by the system because - it does not include cached memory buffers. 
- )pbdoc"); - metal.def( - "get_peak_memory", - &mx::metal::get_peak_memory, - R"pbdoc( - Get the peak amount of used memory in bytes. - - The maximum memory used recorded from the beginning of the program - execution or since the last call to :func:`reset_peak_memory`. - )pbdoc"); - metal.def( - "reset_peak_memory", - &mx::metal::reset_peak_memory, - R"pbdoc( - Reset the peak memory to zero. - )pbdoc"); - metal.def( - "get_cache_memory", - &mx::metal::get_cache_memory, - R"pbdoc( - Get the cache size in bytes. - - The cache includes memory not currently used that has not been returned - to the system allocator. - )pbdoc"); + metal.def("get_active_memory", []() { + DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory"); + return mx::get_active_memory(); + }); + metal.def("get_peak_memory", []() { + DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory"); + return mx::get_peak_memory(); + }); + metal.def("reset_peak_memory", []() { + DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory"); + mx::reset_peak_memory(); + }); + metal.def("get_cache_memory", []() { + DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory"); + return mx::get_cache_memory(); + }); metal.def( "set_memory_limit", - &mx::metal::set_memory_limit, - "limit"_a, - R"pbdoc( - Set the memory limit. - - The memory limit is a guideline for the maximum amount of memory to use - during graph evaluation. If the memory limit is exceeded and there is no - more RAM (including swap when available) allocations will result in an - exception. - - When metal is available the memory limit defaults to 1.5 times the - maximum recommended working set size reported by the device. - - Args: - limit (int): Memory limit in bytes. - - Returns: - int: The previous memory limit in bytes. - )pbdoc"); + [](size_t limit) { + DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit"); + return mx::set_memory_limit(limit); + }, + "limit"_a); metal.def( "set_cache_limit", - &mx::metal::set_cache_limit, - "limit"_a, - R"pbdoc( - Set the free cache limit. - - If using more than the given limit, free memory will be reclaimed - from the cache on the next allocation. To disable the cache, set - the limit to ``0``. - - The cache limit defaults to the memory limit. See - :func:`set_memory_limit` for more details. - - Args: - limit (int): The cache limit in bytes. - - Returns: - int: The previous cache limit in bytes. - )pbdoc"); + [](size_t limit) { + DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit"); + return mx::set_cache_limit(limit); + }, + "limit"_a); metal.def( "set_wired_limit", - &mx::metal::set_wired_limit, - "limit"_a, - R"pbdoc( - Set the wired size limit. - - .. note:: - * This function is only useful on macOS 15.0 or higher. - * The wired limit should remain strictly less than the total - memory size. - - The wired limit is the total size in bytes of memory that will be kept - resident. The default value is ``0``. - - Setting a wired limit larger than system wired limit is an error. You can - increase the system wired limit with: - - .. code-block:: - - sudo sysctl iogpu.wired_limit_mb= - - Use :func:`device_info` to query the system wired limit - (``"max_recommended_working_set_size"``) and the total memory size - (``"memory_size"``). - - Args: - limit (int): The wired limit in bytes. - - Returns: - int: The previous wired limit in bytes. - )pbdoc"); - metal.def( - "clear_cache", - &mx::metal::clear_cache, - R"pbdoc( - Clear the memory cache. 
- - After calling this, :func:`get_cache_memory` should return ``0``. - )pbdoc"); - + [](size_t limit) { + DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit"); + return mx::set_wired_limit(limit); + }, + "limit"_a); + metal.def("clear_cache", []() { + DEPRECATE("mx.metal.clear_cache", "mx.clear_cache"); + mx::clear_cache(); + }); metal.def( "start_capture", &mx::metal::start_capture, diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index ecf9a3a13..eaddecb26 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -12,6 +12,7 @@ void init_array(nb::module_&); void init_device(nb::module_&); void init_stream(nb::module_&); void init_metal(nb::module_&); +void init_memory(nb::module_&); void init_ops(nb::module_&); void init_transforms(nb::module_&); void init_random(nb::module_&); @@ -34,6 +35,7 @@ NB_MODULE(core, m) { init_stream(m); init_array(m); init_metal(m); + init_memory(m); init_ops(m); init_transforms(m); init_random(m); diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index b2eca5f6f..2f0589bb6 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -16,12 +16,12 @@ struct gc_func { }; int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); gc_func* w = (gc_func*)self; Py_VISIT(w->func); for (auto d : w->deps) { Py_VISIT(d); } - Py_VISIT(Py_TYPE(self)); return 0; }; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1577cae18..f98aa80aa 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) { Returns: array: The output array with the corresponding axes reduced. )pbdoc"); + m.def( + "logcumsumexp", + [](const mx::array& a, + std::optional axis, + bool reverse, + bool inclusive, + mx::StreamOrDevice s) { + if (axis) { + return mx::logcumsumexp(a, *axis, reverse, inclusive, s); + } else { + return mx::logcumsumexp( + mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = nb::none(), + nb::sig( + "def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return the cumulative logsumexp of the elements along the given axis. + + Args: + a (array): Input array + axis (int, optional): Optional axis to compute the cumulative logsumexp + over. If unspecified the cumulative logsumexp of the flattened array is + returned. + reverse (bool): Perform the cumulative logsumexp in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. + + Returns: + array: The output array. + )pbdoc"); m.def( "logsumexp", [](const mx::array& a, @@ -3924,7 +3961,7 @@ void init_ops(nb::module_& m) { array or dict: A single array if loading from a ``.npy`` file or a dict mapping names to arrays if loading from a ``.npz`` or - ``.safetensors`` file. If ``return_metadata` is ``True`` an + ``.safetensors`` file. If ``return_metadata`` is ``True`` an additional dictionary of metadata will be returned. 
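A small usage sketch of the ``logcumsumexp`` op introduced above (illustrative input values):

import mlx.core as mx

x = mx.array([0.1, 0.2, 0.3, 0.4])

# Numerically stable cumulative logsumexp along an axis ...
y = mx.logcumsumexp(x, axis=0)
# ... matching, up to floating point error, the naive composition:
assert mx.allclose(y, mx.log(mx.cumsum(mx.exp(x))))

# Also exposed as an array method, with optional reverse accumulation.
y_rev = x.logcumsumexp(axis=0, reverse=True)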
Warning: @@ -4213,9 +4250,10 @@ void init_ops(nb::module_& m) { "group_size"_a = 64, "bits"_a = 4, nb::kw_only(), + "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4228,23 +4266,25 @@ void init_ops(nb::module_& m) { as ``w`` since they represent the same quantized matrix. Args: - x (array): Input array - w (array): Quantized matrix packed in unsigned integers - scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` - lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. - rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. - transpose (bool, optional): Defines whether to multiply with the - transposed ``w`` or not, namely whether we are performing - ``x @ w.T`` or ``x @ w``. Default: ``True``. - group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. Default: ``64``. - bits (int, optional): The number of bits occupied by each element in - ``w``. Default: ``4``. + x (array): Input array + w (array): Quantized matrix packed in unsigned integers + scales (array): The scales to use per ``group_size`` elements of ``w`` + biases (array): The biases to use per ``group_size`` elements of ``w`` + lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. + rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. + transpose (bool, optional): Defines whether to multiply with the + transposed ``w`` or not, namely whether we are performing + ``x @ w.T`` or ``x @ w``. Default: ``True``. + group_size (int, optional): The size of the group in ``w`` that + shares a scale and bias. Default: ``64``. + bits (int, optional): The number of bits occupied by each element in + ``w``. Default: ``4``. + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. Returns: - array: The result of the multiplication of ``x`` with ``w`` - after gathering using ``lhs_indices`` and ``rhs_indices``. + array: The result of the multiplication of ``x`` with ``w`` + after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); m.def( "tensordot", @@ -4274,16 +4314,16 @@ void init_ops(nb::module_& m) { Compute the tensor dot product along the specified axes. Args: - a (array): Input array - b (array): Input array - axes (int or list(list(int)), optional): The number of dimensions to - sum over. If an integer is provided, then sum over the last - ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of - ``b``. If a list of lists is provided, then sum over the - corresponding dimensions of ``a`` and ``b``. Default: 2. + a (array): Input array + b (array): Input array + axes (int or list(list(int)), optional): The number of dimensions to + sum over. 
If an integer is provided, then sum over the last + ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of + ``b``. If a list of lists is provided, then sum over the + corresponding dimensions of ``a`` and ``b``. Default: 2. Returns: - array: The tensor dot product. + array: The tensor dot product. )pbdoc"); m.def( "inner", @@ -4427,9 +4467,10 @@ void init_ops(nb::module_& m) { "lhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(), nb::kw_only(), + "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Matrix multiplication with matrix-level gather. @@ -4448,11 +4489,16 @@ void init_ops(nb::module_& m) { For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)`` + If only one index is passed and it is sorted, the ``sorted_indices`` + flag can be passed for a possible faster implementation. + Args: a (array): Input array. b (array): Input array. lhs_indices (array, optional): Integer indices for ``a``. Default: ``None`` rhs_indices (array, optional): Integer indices for ``b``. Default: ``None`` + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. Returns: array: The output array. @@ -5124,4 +5170,23 @@ void init_ops(nb::module_& m) { [0, 1, 0], [0, 1, 0]], dtype=float32) )pbdoc"); + m.def( + "contiguous", + &mx::contiguous, + nb::arg(), + "allow_col_major"_a = false, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Force an array to be row contiguous. Copy if necessary. + + Args: + a (array): The input to make contiguous + allow_col_major (bool): Consider column major as contiguous and don't copy + + Returns: + array: The row or col contiguous output. + )pbdoc"); } diff --git a/python/src/random.cpp b/python/src/random.cpp index 4b82f5479..e9c0a87fc 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -467,7 +467,7 @@ void init_random(nb::module_& parent_module) { array: The output array of random values. 
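A brief sketch of the ``contiguous`` op documented above (illustrative; the array contents are arbitrary):

import mlx.core as mx

a = mx.arange(6).reshape(2, 3).T            # transposed, not row contiguous
b = mx.contiguous(a)                        # row-contiguous copy
c = mx.contiguous(a, allow_col_major=True)  # column-major accepted, copy may be skipped

assert mx.array_equal(a, b)                 # same values, different layout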
)pbdoc"); m.def( - "permuation", + "permutation", [](const std::variant& x, int axis, const std::optional& key_, diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 4a5e2e6ac..c47942b72 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -960,6 +960,11 @@ class PyCustomFunction { }; int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + auto* p = nb::inst_ptr(self); nb::handle v = nb::find(p->fun_); Py_VISIT(v.ptr()); @@ -975,7 +980,6 @@ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { nb::handle v = nb::find(*(p->vmap_fun_)); Py_VISIT(v.ptr()); } - Py_VISIT(Py_TYPE(self)); return 0; } int py_custom_function_tp_clear(PyObject* self) { diff --git a/python/src/utils.cpp b/python/src/utils.cpp index e6ca346dc..08f78bdf4 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -2,6 +2,7 @@ #include "python/src/utils.h" #include "mlx/ops.h" +#include "mlx/utils.h" #include "python/src/convert.h" mx::array to_array( @@ -16,6 +17,16 @@ mx::array to_array( ? mx::int64 : mx::int32; auto out_t = dtype.value_or(default_type); + if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) { + auto info = mx::iinfo(out_t); + if (val < info.min || val > static_cast(info.max)) { + std::ostringstream msg; + msg << "Converting " << val << " to " << out_t + << " would result in overflow."; + throw std::invalid_argument(msg.str()); + } + } + // bool_ is an exception and is always promoted return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t); } else if (auto pv = std::get_if(&v); pv) { diff --git a/python/tests/mlx_distributed_tests.py b/python/tests/mlx_distributed_tests.py new file mode 100644 index 000000000..5feb51bc9 --- /dev/null +++ b/python/tests/mlx_distributed_tests.py @@ -0,0 +1,250 @@ +# Copyright © 2025 Apple Inc. 
+ +import unittest + +import mlx.core as mx +import mlx.nn as nn +import mlx_tests +from mlx.nn.layers.distributed import shard_inplace, shard_linear +from mlx.nn.utils import average_gradients + + +class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase): + def test_average_gradients(self): + original_all_sum = mx.distributed.all_sum + n_calls = 0 + xtype = None + + def new_all_sum(x, **kwargs): + nonlocal n_calls + nonlocal xtype + + n_calls += 1 + if xtype is not None: + self.assertEqual(xtype, x.dtype) + + return original_all_sum(x, **kwargs) + + mx.distributed.all_sum = new_all_sum + + try: + grads = [mx.ones(10) for i in range(10)] + new_grads = average_gradients(grads) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 1) + + n_calls = 0 + new_grads = average_gradients(grads, all_reduce_size=4 * 50) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 2) + + n_calls = 0 + new_grads = average_gradients(grads, all_reduce_size=0) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 10) + + n_calls = 0 + xtype = mx.float16 + new_grads = average_gradients( + grads, all_reduce_size=2 * 50, communication_type=mx.float16 + ) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(g.dtype == mx.float32 for g in new_grads)) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 2) + + finally: + mx.distributed.all_sum = original_all_sum + + def test_donation(self): + x = mx.random.normal((1024,)) + mx.eval(x) + mx.synchronize() + + mx.reset_peak_memory() + scale = mx.array(2.0) + y = mx.distributed.all_sum(x) + mx.eval(y) + mx.synchronize() + all_sum_only = mx.get_peak_memory() + y = mx.distributed.all_sum(x) * scale + mx.eval(y) + mx.synchronize() + all_sum_with_binary = mx.get_peak_memory() + + self.assertEqual(all_sum_only, all_sum_with_binary) + + def test_shard_linear(self): + # Seed the prng to have the same inputs and weights generated everywhere + mx.random.seed(0xF0F0F0F0) + + # Prepare inputs + world = mx.distributed.init() + part = ( + slice(None), + slice( + world.rank() * 1024 // world.size(), + (world.rank() + 1) * 1024 // world.size(), + ), + ) + x = mx.random.normal((4, 1024)) + + # Create and shard some linear layers + lin = nn.Linear(1024, 1024, bias=True) + slin1 = shard_linear(lin, "all-to-sharded") + slin2 = shard_linear(lin, "sharded-to-all") + y = lin(x) + y1 = slin1(x) + y2 = slin2(x[part]) + self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4)) + self.assertTrue(mx.allclose(y[part], y1)) + + # And their quant versions + qlin = lin.to_quantized() + slin1 = shard_linear(qlin, "all-to-sharded") + slin2 = shard_linear(qlin, "sharded-to-all") + y = qlin(x) + y1 = slin1(x) + y2 = slin2(x[part]) + self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4)) + self.assertTrue(mx.allclose(y[part], y1)) + + # Check the backward works as expected + def dummy_loss(model, x, y): + return (model(x) * y).sum() + + mod = nn.Sequential( + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + ) + smod = nn.Sequential( + shard_linear(mod.layers[0], "all-to-sharded"), + shard_linear(mod.layers[1], "sharded-to-all"), + shard_linear(mod.layers[2], "all-to-sharded"), + shard_linear(mod.layers[3], "sharded-to-all"), + ) 
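# The donation test above uses the memory API that this diff promotes from the
# mx.metal submodule to the top level. A minimal sketch of the new names
# (illustrative only; the mx.metal.* equivalents keep working but print a
# deprecation message and forward here):
import mlx.core as mx

mx.reset_peak_memory()
x = mx.random.normal((1024, 1024))
mx.eval(x)
print(mx.get_active_memory())  # bytes held by live arrays, excluding the cache
print(mx.get_peak_memory())    # high-water mark since the last reset
print(mx.get_cache_memory())   # bytes retained by the allocator cache
mx.clear_cache()               # return cached buffers; get_cache_memory() -> 0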
+ + grad1 = nn.value_and_grad(mod, dummy_loss) + grad2 = nn.value_and_grad(smod, dummy_loss) + + x = mx.random.normal((4, 128)) + y = mx.random.normal((4, 128)) + + l1, g1 = grad1(mod, x, y) + l2, g2 = grad2(smod, x, y) + mx.eval(l1, g1, l2, g2) + + part = slice( + world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size() + ) + self.assertTrue(mx.allclose(l1, l2)) + self.assertTrue( + mx.allclose( + g1["layers"][0]["weight"][part], + g2["layers"][0]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][2]["weight"][part], + g2["layers"][2]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][1]["weight"][:, part], + g2["layers"][1]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][3]["weight"][:, part], + g2["layers"][3]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][0]["bias"][part], + g2["layers"][0]["bias"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][2]["bias"][part], + g2["layers"][2]["bias"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4 + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4 + ) + ) + + def test_shard_predicate(self): + mx.random.seed(0xF0F0F0F0) + + class MyConv(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.aggregate = kwargs.pop("aggregate", False) + self.conv = nn.Conv2d(*args, **kwargs) + + def __call__(self, x): + x = self.conv(x) + if self.aggregate: + x = mx.distributed.all_sum(x) + return x + + def sharding(path, weight): + parts = path.split(".") + even = int(parts[1]) % 2 == 0 + if even: + return 0 + else: + return -1 if parts[-1] != "bias" else None + + mod = nn.Sequential( + MyConv(3, 128, kernel_size=3), + MyConv(128, 128, kernel_size=3), + MyConv(128, 128, kernel_size=3), + MyConv(128, 3, kernel_size=3), + ) + smod = nn.Sequential( + MyConv(3, 128, kernel_size=3), + MyConv(128, 128, kernel_size=3, aggregate=True), + MyConv(128, 128, kernel_size=3), + MyConv(128, 3, kernel_size=3, aggregate=True), + ) + smod.update(mod.parameters()) + shard_inplace(smod, sharding) + + x = mx.random.normal((4, 16, 16, 3)) + y1 = mod(x) + y2 = smod(x) + self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4)) diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index 0d172cee4..26d340dbe 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -3,11 +3,14 @@ import unittest import mlx.core as mx -import mlx_tests -from mlx.nn.utils import average_gradients +import mlx_distributed_tests -class TestDistributed(mlx_tests.MLXTestCase): +class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): + @classmethod + def setUpClass(cls): + world = mx.distributed.init(strict=True, backend="mpi") + def test_groups(self): world = mx.distributed.init() self.assertEqual(world.size(), 8) @@ -27,27 +30,51 @@ class TestDistributed(mlx_tests.MLXTestCase): def test_all_reduce(self): world = mx.distributed.init() dtypes = [ - mx.int8, - mx.uint8, - mx.int16, - mx.uint16, - mx.int32, - mx.uint32, - mx.float32, - mx.float16, - mx.bfloat16, - mx.complex64, + (mx.int8, 0), + (mx.uint8, 0), + (mx.int16, 0), + (mx.uint16, 0), + (mx.int32, 0), + (mx.uint32, 0), + (mx.float32, 1e-6), + 
(mx.float16, 5e-3), + (mx.bfloat16, 1e-1), + (mx.complex64, 1e-6), ] - for dt in dtypes: - x = mx.ones((2, 2, 4), dtype=dt) - y = mx.distributed.all_sum(x) - self.assertTrue(mx.all(y == world.size())) + sizes = [ + (7,), + (10,), + (1024,), + (1024, 1024), + ] + key = mx.random.key(0) + group = world.split(world.rank() % 2) - sub = world.split(world.rank() % 2) - for dt in dtypes: - x = mx.ones((2, 2, 4), dtype=dt) - y = mx.distributed.all_sum(x, group=sub) - self.assertTrue(mx.all(y == sub.size())) + for dt, rtol in dtypes: + for sh in sizes: + for g in [world, group]: + x = ( + mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10 + ).astype(dt) + + # All sum + y = mx.distributed.all_sum(x[g.rank()], group=g) + z = x.sum(0) + maxrelerror = (y - z).abs() + if rtol > 0: + maxrelerror /= z.abs() + maxrelerror = maxrelerror.max() + self.assertLessEqual(maxrelerror, rtol) + + # All max + y = mx.distributed.all_max(x[g.rank()], group=g) + z = x.max(0) + self.assertTrue(mx.all(y == z)) + + # All min + y = mx.distributed.all_min(x[g.rank()], group=g) + z = x.min(0) + self.assertTrue(mx.all(y == z)) def test_all_gather(self): world = mx.distributed.init() @@ -121,77 +148,6 @@ class TestDistributed(mlx_tests.MLXTestCase): x = mx.distributed.recv_like(x, neighbor, group=pairs) mx.eval(y, x) - def test_average_gradients(self): - original_all_sum = mx.distributed.all_sum - n_calls = 0 - xtype = None - - def new_all_sum(x, **kwargs): - nonlocal n_calls - nonlocal xtype - - n_calls += 1 - if xtype is not None: - self.assertEqual(xtype, x.dtype) - - return original_all_sum(x, **kwargs) - - mx.distributed.all_sum = new_all_sum - - try: - grads = [mx.ones(10) for i in range(10)] - new_grads = average_gradients(grads) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 1) - - n_calls = 0 - new_grads = average_gradients(grads, all_reduce_size=4 * 50) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 2) - - n_calls = 0 - new_grads = average_gradients(grads, all_reduce_size=0) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 10) - - n_calls = 0 - xtype = mx.float16 - new_grads = average_gradients( - grads, all_reduce_size=2 * 50, communication_type=mx.float16 - ) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(g.dtype == mx.float32 for g in new_grads)) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 2) - - finally: - mx.distributed.all_sum = original_all_sum - - def test_donation(self): - x = mx.random.normal((1024,)) - mx.eval(x) - mx.synchronize(mx.default_stream(mx.default_device())) - - mx.metal.reset_peak_memory() - scale = mx.array(2.0) - y = mx.distributed.all_sum(x) - mx.eval(y) - mx.synchronize(mx.default_stream(mx.default_device())) - all_sum_only = mx.metal.get_peak_memory() - y = mx.distributed.all_sum(x) * scale - mx.eval(y) - mx.synchronize(mx.default_stream(mx.default_device())) - all_sum_with_binary = mx.metal.get_peak_memory() - - self.assertEqual(all_sum_only, all_sum_with_binary) - if __name__ == "__main__": unittest.main() diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py index 0c68914bf..77d45dbad 100644 --- a/python/tests/ring_test_distributed.py +++ b/python/tests/ring_test_distributed.py @@ 
-3,10 +3,10 @@ import unittest import mlx.core as mx -import mlx_tests +import mlx_distributed_tests -class TestRingDistributed(mlx_tests.MLXTestCase): +class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): @classmethod def setUpClass(cls): world = mx.distributed.init(strict=True, backend="ring") @@ -44,18 +44,51 @@ class TestRingDistributed(mlx_tests.MLXTestCase): (1024, 1024), ] key = mx.random.key(0) + reductions = ["min", "max", "sum"] + for dt, rtol in dtypes: for sh in sizes: x = ( mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 ).astype(dt) + + # All sum y = mx.distributed.all_sum(x[world.rank()]) - z = sum( - x[i] for i in range(world.size()) - ) # to ensure that we don't sum to int32 - maxrelerror = ((y - z).abs() / z.abs()).max() + z = x.sum(0) + maxrelerror = (y - z).abs() + if rtol > 0: + maxrelerror /= z.abs() + maxrelerror = maxrelerror.max() self.assertLessEqual(maxrelerror, rtol) + # All max + y = mx.distributed.all_max(x[world.rank()]) + z = x.max(0) + self.assertTrue(mx.all(y == z)) + + # All min + y = mx.distributed.all_min(x[world.rank()]) + z = x.min(0) + self.assertTrue(mx.all(y == z)) + + def test_all_gather(self): + world = mx.distributed.init() + dtypes = [ + mx.int8, + mx.uint8, + mx.int16, + mx.uint16, + mx.int32, + mx.uint32, + mx.float32, + mx.complex64, + ] + for dt in dtypes: + x = mx.ones((2, 2, 4), dtype=dt) + y = mx.distributed.all_gather(x) + self.assertEqual(y.shape, (world.size() * 2, 2, 4)) + self.assertTrue(mx.all(y == 1)) + def test_send_recv(self): world = mx.distributed.init() dtypes = [ diff --git a/python/tests/test_array.py b/python/tests/test_array.py index b8917b75c..fa5784ea9 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -109,6 +109,18 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) + def test_iinfo(self): + with self.assertRaises(ValueError): + mx.iinfo(mx.float32) + + self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min) + self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max) + self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32) + + self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min) + self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max) + self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8) + class TestEquality(mlx_tests.MLXTestCase): def test_array_eq_array(self): @@ -1496,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase): ("prod", 1), ("min", 1), ("max", 1), + ("logcumsumexp", 1), ("logsumexp", 1), ("mean", 1), ("var", 1), @@ -1803,7 +1816,6 @@ class TestArray(mlx_tests.MLXTestCase): b = pickle.loads(pickle.dumps(a)) self.assertTrue(mx.array_equal(mx.array(a), mx.array(b))) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_multi_output_leak(self): def fun(): a = mx.zeros((2**20)) @@ -1813,10 +1825,10 @@ class TestArray(mlx_tests.MLXTestCase): fun() mx.synchronize() - peak_1 = mx.metal.get_peak_memory() + peak_1 = mx.get_peak_memory() fun() mx.synchronize() - peak_2 = mx.metal.get_peak_memory() + peak_2 = mx.get_peak_memory() self.assertEqual(peak_1, peak_2) def fun(): @@ -1826,10 +1838,10 @@ class TestArray(mlx_tests.MLXTestCase): fun() mx.synchronize() - peak_1 = mx.metal.get_peak_memory() + peak_1 = mx.get_peak_memory() fun() mx.synchronize() - peak_2 = mx.metal.get_peak_memory() + peak_2 = mx.get_peak_memory() self.assertEqual(peak_1, peak_2) def 
test_add_numpy(self): @@ -2000,6 +2012,14 @@ class TestArray(mlx_tests.MLXTestCase): used = get_mem() self.assertEqual(expected, used) + def test_scalar_integer_conversion_overflow(self): + y = mx.array(2000000000, dtype=mx.int32) + x = 3000000000 + with self.assertRaises(ValueError): + y + x + with self.assertRaises(ValueError): + mx.add(y, x) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 350b09837..ec9d957ea 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -745,11 +745,8 @@ class TestAutograd(mlx_tests.MLXTestCase): mx.custom_function, mx.checkpoint, ]: - if mx.metal.is_available(): - mx.synchronize(mx.default_stream(mx.default_device())) - mem_pre = mx.metal.get_active_memory() - else: - mem_pre = 0 + mx.synchronize() + mem_pre = mx.get_active_memory() def outer(): d = {} @@ -763,12 +760,7 @@ class TestAutograd(mlx_tests.MLXTestCase): for _ in range(5): outer() gc.collect() - - if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() - else: - mem_post = 0 - + mem_post = mx.get_active_memory() self.assertEqual(mem_pre, mem_post) def test_grad_with_copies(self): diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 985ca5ffb..6fca4885b 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -12,7 +12,7 @@ import numpy as np class TestBlas(mlx_tests.MLXTestCase): @property def dtypes(self): - return ["float32", "float16"] if mx.metal.is_available() else ["float32"] + return ["float32", "float16"] def __gemm_test( self, @@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase): lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2)) rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2)) M = a.shape[-2] - N = b.shape[-2] + N = b.shape[-1] K = a.shape[-1] a = a.reshape((-1, M, K)) @@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase): out_gemm = (b @ c)[0] self.assertTrue(mx.allclose(out_gemv, out_gemm)) + def test_complex_gemv(self): + M = 16 + N = 50 + + def rand(shape): + return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) + + a = rand((M, N)) + b = rand((N, 1)) + c = mx.matmul(a, b) + c_np = np.matmul(a, b) + self.assertTrue(np.allclose(c, c_np)) + + # Transposed + a = rand((N, M)) + b = rand((N, 1)) + c = mx.matmul(a.T, b) + c_np = np.matmul(np.array(a).T, b) + self.assertTrue(np.allclose(c, c_np)) + + def test_complex_gemm(self): + M = 16 + K = 50 + N = 32 + + def rand(shape): + return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) + + a = rand((M, K)) + b = rand((K, N)) + c = mx.matmul(a, b) + c_np = np.matmul(a, b) + self.assertTrue(np.allclose(c, c_np)) + + # Test addmm + M = 16 + K = 50 + N = 32 + + def rand(shape): + return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) + + a = rand((M, K)) + b = rand((K, N)) + c = rand((M, N)) + out = mx.addmm(c, a, b, 2.0, 2.0) + out_np = 2.0 * np.matmul(a, b) + 2.0 * c + self.assertTrue(np.allclose(out, out_np)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 8cf3b4e08..f5ce496cd 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -955,7 +955,7 @@ class TestCompile(mlx_tests.MLXTestCase): def test_leaks(self): gc.collect() if mx.metal.is_available(): - mem_pre = mx.metal.get_active_memory() + mem_pre = mx.get_active_memory() else: mem_pre = 0 @@ -973,7 +973,7 @@ class TestCompile(mlx_tests.MLXTestCase): 
gc.collect() if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() + mem_post = mx.get_active_memory() else: mem_post = 0 diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 9dd8fd140..671c86a32 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -707,9 +707,11 @@ class TestConv(mlx_tests.MLXTestCase): flip=flip, np_dtype=np_dtype, ): + np.random.seed(0) scale = 1.0 / math.sqrt(np.prod(wt_shape[1:])) - in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype) - wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype) + scale = min(0.3, scale) + in_np = np.random.normal(0, scale, in_shape).astype(np_dtype) + wt_np = np.random.normal(0, scale, wt_shape).astype(np_dtype) in_mx, wt_mx = map(mx.array, (in_np, wt_np)) @@ -1050,6 +1052,42 @@ class TestConv(mlx_tests.MLXTestCase): y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1) self.assertTrue(mx.allclose(y1, y2)) + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_depthwise(self): + + # fmt: off + shapes = ( + # N, H, W, C kH, kW, O, strides, padding, groups + ( 2, 16, 16, 32, 1, 1, 32, (2, 2), (1, 1), 32), + ( 1, 16, 16, 32, 3, 3, 32, (2, 2), (1, 1), 32), + ( 1, 32, 32, 32, 7, 7, 32, (1, 1), (3, 3), 32), + ( 3, 32, 32, 32, 5, 5, 32, (1, 2), (0, 0), 32), + ( 1, 32, 32, 32, 7, 7, 32, (2, 1), (1, 3), 32), + ) + # fmt: on + + dtypes = [np.float32] + if mx.default_device() == mx.gpu: + dtypes += [np.float16] + + for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: + for dtype in dtypes: + for flip in [False, True]: + Cw = C // groups + + self.__conv_general_test( + (N, H, W, C), + (O, kH, kW, Cw), + strides, + padding, + kernel_dilation=1, + input_dilation=1, + groups=groups, + flip=flip, + np_dtype=dtype, + atol=2e-5 if dtype == np.float32 else 5e-4, + ) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index ebcf64c7a..fcd424343 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -117,10 +117,9 @@ class TestEval(mlx_tests.MLXTestCase): out = mx.vjp(fn, (x,), (y,)) out = mx.vjp(fn, (x,), (y,)) - if mx.metal.is_available(): - peak_mem = mx.metal.get_peak_memory() - out = mx.vjp(fn, (x,), (y,)) - self.assertEqual(peak_mem, mx.metal.get_peak_memory()) + peak_mem = mx.get_peak_memory() + out = mx.vjp(fn, (x,), (y,)) + self.assertEqual(peak_mem, mx.get_peak_memory()) def test_async_eval_with_multiple_streams(self): x = mx.array([1.0]) @@ -137,7 +136,6 @@ class TestEval(mlx_tests.MLXTestCase): mx.async_eval(x) mx.eval(a + b) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_donation_for_noops(self): def fun(x): s = x.shape @@ -151,11 +149,11 @@ class TestEval(mlx_tests.MLXTestCase): x = mx.zeros((4096, 4096)) mx.eval(x) - pre = mx.metal.get_peak_memory() + pre = mx.get_peak_memory() out = fun(x) del x mx.eval(out) - post = mx.metal.get_peak_memory() + post = mx.get_peak_memory() self.assertEqual(pre, post) def fun(x): @@ -167,11 +165,11 @@ class TestEval(mlx_tests.MLXTestCase): x = mx.zeros((4096 * 4096,)) mx.eval(x) - pre = mx.metal.get_peak_memory() + pre = mx.get_peak_memory() out = fun(x) del x mx.eval(out) - post = mx.metal.get_peak_memory() + post = mx.get_peak_memory() self.assertEqual(pre, post) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @@ -187,7 +185,7 @@ class TestEval(mlx_tests.MLXTestCase): s1 = mx.default_stream(mx.gpu) s2 = mx.new_stream(mx.gpu) - old_limit = mx.metal.set_memory_limit(1000) + 
old_limit = mx.set_memory_limit(1000) x = mx.ones((512, 512), stream=s2) for _ in range(80): @@ -195,7 +193,7 @@ class TestEval(mlx_tests.MLXTestCase): y = mx.abs(x, stream=s2) z = mx.abs(y, stream=s2) mx.eval(z) - mx.metal.set_memory_limit(old_limit) + mx.set_memory_limit(old_limit) if __name__ == "__main__": diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index fd62a58f6..2b4b425ca 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -243,7 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase): def test_leaks(self): path = os.path.join(self.test_dir, "fn.mlxfn") if mx.metal.is_available(): - mem_pre = mx.metal.get_active_memory() + mem_pre = mx.get_active_memory() else: mem_pre = 0 @@ -261,7 +261,7 @@ class TestExportImport(mlx_tests.MLXTestCase): gc.collect() if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() + mem_post = mx.get_active_memory() else: mem_post = 0 diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 78e03159f..d35a2b1da 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -95,7 +95,13 @@ def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): def mlx_primitives_sdpa(q, k, v, scale, mask=None): p = (q * scale) @ k.transpose(0, 1, 3, 2) if mask is not None: - if mask.dtype == mx.bool_: + if mask == "causal": + q_offset = max(0, k.shape[2] - q.shape[2]) + q_indices = mx.arange(q_offset, q_offset + q.shape[2]) + k_indices = mx.arange(k.shape[2]) + mask = q_indices[:, None] >= k_indices[None] + p = mx.where(mask, p, mx.finfo(mx.float32).min) + elif mask.dtype == mx.bool_: p = mx.where(mask, p, mx.finfo(mx.float32).min) else: p += mask @@ -176,7 +182,10 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase): reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale) o_mlx = mx.fast.scaled_dot_product_attention( - q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2 + q_mlx, + k_mlx, + v_mlx, + scale=scale, ) self.assertListEqual(list(reference.shape), list(o_mlx.shape)) @@ -338,10 +347,16 @@ class TestFastSDPA(mlx_tests.MLXTestCase): ) masks = [ + None, mx.array(True), mx.array([True] * (L - 10) + [False] * 10), mx.random.uniform(shape=(Nq, 1, L)) > 0.2, mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + mx.random.uniform(shape=(Nq, 1, L)), + mx.random.uniform(shape=(L, 1, Nq)).T, + mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2), + mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2), + "causal", ] for m in masks: ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) @@ -366,6 +381,11 @@ class TestFastSDPA(mlx_tests.MLXTestCase): mx.array([True] * (L - 10) + [False] * 10), mx.random.uniform(shape=(Nq, 1, L)) > 0.2, mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + mx.random.uniform(shape=(Nq, 1, L)), + mx.random.uniform(shape=(L, 1, Nq)).T, + mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2), + mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2), + "causal", ] for m in masks: ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) @@ -381,7 +401,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): def test_fast_sdpa_few_query(self): D = 64 L = 43 - Lq = 4 + Lq = 8 Nq = 8 Nkv = 1 scale = 1.0 @@ -392,10 +412,12 @@ class TestFastSDPA(mlx_tests.MLXTestCase): v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) masks = [ + None, mx.array(True), mx.array([True] * (L - 10) + [False] * 10), mx.random.uniform(shape=(Nq, 1, L)) > 0.2, mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + "causal", ] for m in masks: ref = 
mlx_primitives_sdpa(q, k, v, scale, mask=m) @@ -416,10 +438,12 @@ class TestFastSDPA(mlx_tests.MLXTestCase): v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) masks = [ + None, mx.array(True), mx.array([True] * (L - 10) + [False] * 10), mx.random.uniform(shape=(Nq, 1, L)) > 0.2, mx.random.uniform(shape=(L, 1, Nq)).T > 0.2, + "causal", ] for m in masks: ref = mlx_primitives_sdpa(q, k, v, scale, mask=m) @@ -543,6 +567,50 @@ class TestSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_sdpa_prommote_mask(self): + mask = mx.array(2.0, mx.bfloat16) + D = 64 + Nq = 4 + Nkv = 1 + scale = 1.0 + L = 256 + + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D)) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + def test_sdpa_nan_bug(self): + N = 128 + q_shape = (1, 1, N, 128) + kv_shape = (1, 1, N, 128) + q = mx.random.uniform(shape=q_shape) + k = mx.random.uniform(shape=kv_shape) + v = mx.random.uniform(shape=kv_shape) + + # Make boolean window causal mask + linds = rinds = mx.arange(N) + linds = linds[:, None] + rinds = rinds[None] + mask = linds >= rinds + mask = mask & (linds <= rinds + 111) + + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0) + expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0) + self.assertFalse(mx.isnan(out).any().item()) + self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4) + + # And an additive one + mask = mx.log(mask) + + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0) + expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0) + self.assertFalse(mx.isnan(out).any().item()) + self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4) + if __name__ == "__main__": unittest.main(failfast=True) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 95f9f7a54..ec9a48f00 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -182,6 +182,18 @@ class TestFFT(mlx_tests.MLXTestCase): out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1)))) np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5) + def test_fft_into_ifft(self): + n_fft = 8193 + mx.random.seed(0) + + segment = mx.random.normal(shape=[1, n_fft]) + 1j * mx.random.normal( + shape=(1, n_fft) + ) + segment = mx.fft.fft(segment, n=n_fft) + r = mx.fft.ifft(segment, n=n_fft) + r_np = np.fft.ifft(segment, n=n_fft) + self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_load.py b/python/tests/test_load.py index fbc67f3c2..341564dae 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -385,16 +385,16 @@ class TestLoad(mlx_tests.MLXTestCase): mx.eval(x) save_file = os.path.join(self.test_dir, "donation.npy") mx.save(save_file, x) - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() - mx.metal.reset_peak_memory() + mx.reset_peak_memory() scale = mx.array(2.0) y = mx.load(save_file) mx.eval(y) - load_only = mx.metal.get_peak_memory() + load_only = mx.get_peak_memory() y = mx.load(save_file) * scale mx.eval(y) - load_with_binary = mx.metal.get_peak_memory() + load_with_binary = 
mx.get_peak_memory() self.assertEqual(load_only, load_with_binary) diff --git a/python/tests/test_memory.py b/python/tests/test_memory.py new file mode 100644 index 000000000..7343bdc91 --- /dev/null +++ b/python/tests/test_memory.py @@ -0,0 +1,63 @@ +# Copyright © 2023-2024 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx_tests + + +class TestMemory(mlx_tests.MLXTestCase): + def test_memory_info(self): + old_limit = mx.set_cache_limit(0) + + a = mx.zeros((4096,)) + mx.eval(a) + del a + self.assertEqual(mx.get_cache_memory(), 0) + self.assertEqual(mx.set_cache_limit(old_limit), 0) + self.assertEqual(mx.set_cache_limit(old_limit), old_limit) + + old_limit = mx.set_memory_limit(10) + self.assertTrue(mx.set_memory_limit(old_limit), 10) + self.assertTrue(mx.set_memory_limit(old_limit), old_limit) + + # Query active and peak memory + a = mx.zeros((4096,)) + mx.eval(a) + mx.synchronize() + active_mem = mx.get_active_memory() + self.assertTrue(active_mem >= 4096 * 4) + + b = mx.zeros((4096,)) + mx.eval(b) + del b + mx.synchronize() + + new_active_mem = mx.get_active_memory() + self.assertEqual(new_active_mem, active_mem) + peak_mem = mx.get_peak_memory() + self.assertTrue(peak_mem >= 4096 * 8) + + if mx.metal.is_available(): + cache_mem = mx.get_cache_memory() + self.assertTrue(cache_mem >= 4096 * 4) + + mx.clear_cache() + self.assertEqual(mx.get_cache_memory(), 0) + + mx.reset_peak_memory() + self.assertEqual(mx.get_peak_memory(), 0) + + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_wired_memory(self): + old_limit = mx.set_wired_limit(1000) + old_limit = mx.set_wired_limit(0) + self.assertEqual(old_limit, 1000) + + max_size = mx.metal.device_info()["max_recommended_working_set_size"] + with self.assertRaises(ValueError): + mx.set_wired_limit(max_size + 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py deleted file mode 100644 index 81cefabce..000000000 --- a/python/tests/test_metal.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -import unittest - -import mlx.core as mx -import mlx_tests - - -class TestMetal(mlx_tests.MLXTestCase): - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") - def test_memory_info(self): - old_limit = mx.metal.set_cache_limit(0) - - a = mx.zeros((4096,)) - mx.eval(a) - del a - self.assertEqual(mx.metal.get_cache_memory(), 0) - self.assertEqual(mx.metal.set_cache_limit(old_limit), 0) - self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit) - - old_limit = mx.metal.set_memory_limit(10) - self.assertTrue(mx.metal.set_memory_limit(old_limit), 10) - self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit) - - # Query active and peak memory - a = mx.zeros((4096,)) - mx.eval(a) - mx.synchronize() - active_mem = mx.metal.get_active_memory() - self.assertTrue(active_mem >= 4096 * 4) - - b = mx.zeros((4096,)) - mx.eval(b) - del b - mx.synchronize() - - new_active_mem = mx.metal.get_active_memory() - self.assertEqual(new_active_mem, active_mem) - peak_mem = mx.metal.get_peak_memory() - self.assertTrue(peak_mem >= 4096 * 8) - cache_mem = mx.metal.get_cache_memory() - self.assertTrue(cache_mem >= 4096 * 4) - - mx.metal.clear_cache() - self.assertEqual(mx.metal.get_cache_memory(), 0) - - mx.metal.reset_peak_memory() - self.assertEqual(mx.metal.get_peak_memory(), 0) - - old_limit = mx.metal.set_wired_limit(1000) - old_limit = mx.metal.set_wired_limit(0) - self.assertEqual(old_limit, 1000) - - max_size = mx.metal.device_info()["max_recommended_working_set_size"] - with self.assertRaises(ValueError): - mx.metal.set_wired_limit(max_size + 10) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 8e1cd8efd..4fcb31f18 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -690,15 +690,34 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.array_equal(b_npy, b_mlx)) def test_logsumexp(self): + def logsumexp(x, axes=None): + maxs = mx.max(x, axis=axes, keepdims=True) + return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs + x = mx.array( [ [1.0, 2.0], [3.0, 4.0], ] ) - xnp = np.array(x.tolist(), dtype=np.float32) - expected = np.log(np.sum(np.exp(xnp))) - self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item())) + self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item())) + + x = mx.random.uniform(shape=(1025,)) + self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) + + # Transposed + x = mx.random.uniform(shape=(2, 2, 8)) + x = x.swapaxes(0, 1) + self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) + + # Broadcast + x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8)) + self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) + + # Large + x = mx.random.uniform(shape=(1025,)) + x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8)) + self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) def test_mean(self): x = mx.array( @@ -845,6 +864,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + a = mx.array(1.0) + 1j * mx.array(2.0) + result = mx.log(a) + expected = np.log(np.array(a)) + self.assertTrue(np.allclose(result, expected)) + def test_log2(self): a = mx.array([0.5, 1, 2, 10, 16]) result = mx.log2(a) @@ -852,6 +876,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + a = mx.array(1.0) + 1j * mx.array(2.0) + result = mx.log2(a) + expected = np.log2(np.array(a)) + self.assertTrue(np.allclose(result, 
expected)) + def test_log10(self): a = mx.array([0.1, 1, 10, 20, 100]) result = mx.log10(a) @@ -859,6 +888,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + a = mx.array(1.0) + 1j * mx.array(2.0) + result = mx.log10(a) + expected = np.log10(np.array(a)) + self.assertTrue(np.allclose(result, expected)) + def test_exp(self): a = mx.array([0, 0.5, -0.5, 5]) result = mx.exp(a) @@ -1628,6 +1662,15 @@ class TestOps(mlx_tests.MLXTestCase): x = mx.full((n,), vals=-float("inf")) self.assertTrue(mx.all(mx.isnan(mx.softmax(x)))) + # Transposed inputs + a = mx.random.uniform(shape=(32, 32, 32)) + b = mx.softmax(a, axis=-1) + c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1) + self.assertEqual((b - c).abs().max().item(), 0.0) + + with self.assertRaises(ValueError): + mx.softmax(mx.array(1.0), axis=-1) + def test_concatenate(self): a_npy = np.random.randn(32, 32, 32) b_npy = np.random.randn(32, 32, 32) @@ -1814,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase): y = mx.as_strided(x, (x.size,), (-1,), x.size - 1) self.assertTrue(mx.array_equal(y, x[::-1])) + def test_logcumsumexp(self): + npop = np.logaddexp.accumulate + mxop = mx.logcumsumexp + + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + a_mlx = mx.array(a_npy) + + for axis in (0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + + edge_cases_npy = [ + np.float32([-float("inf")] * 8), + np.float32([-float("inf"), 0, -float("inf")]), + np.float32([-float("inf"), float("inf"), -float("inf")]), + ] + edge_cases_mlx = [mx.array(a) for a in edge_cases_npy] + + for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx): + c_npy = npop(a_npy, axis=0) + c_mlx = mxop(a_mlx, axis=0) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy) @@ -1901,13 +1968,13 @@ class TestOps(mlx_tests.MLXTestCase): x = mx.cumsum(x) return x - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() mx.eval(fn(2)) - mx.synchronize(mx.default_stream(mx.default_device())) - mem2 = mx.metal.get_peak_memory() + mx.synchronize() + mem2 = mx.get_peak_memory() mx.eval(fn(4)) - mx.synchronize(mx.default_stream(mx.default_device())) - mem4 = mx.metal.get_peak_memory() + mx.synchronize() + mem4 = mx.get_peak_memory() self.assertEqual(mem2, mem4) def test_squeeze_expand(self): diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 160eb6400..eeefcd94f 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase): tests = product( [128, 64, 32], # group_size [2, 3, 4, 6, 8], # bits - [128, 256], # M + [32, 128, 256], # M [128, 256, 67], # N [0, 1, 3, 8], # B ) for group_size, bits, M, N, B in tests: with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): + if M < group_size: + continue x_shape = (1, N) if B == 0 else (B, 1, N) w_shape = (N, M) if B == 0 else (B, N, M) x = mx.random.normal(shape=x_shape, key=k1) @@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase): ) for kwargs in inputs: + test_shape(1, 32, 128, **kwargs) test_shape(32, 32, 256, **kwargs) test_shape(1, 32, 256, **kwargs) test_shape(32, 256, 32, transpose=False, **kwargs) @@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase): g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices) 
self.assertTrue(mx.allclose(g1, g2, atol=1e-4)) + def test_gather_qmm_sorted(self): + def quantize(w, transpose=True, group_size=64, bits=4): + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + if transpose: + w_hat = w_hat.swapaxes(-1, -2) + return w_hat, qw, s, b + + def gather_sort(x, indices): + N, M = indices.shape + indices = indices.flatten() + order = mx.argsort(indices) + inv_order = mx.argsort(order) + return x.flatten(0, -3)[order // M], indices[order], inv_order + + def scatter_unsort(x, inv_order, shape=None): + x = x[inv_order] + if shape is not None: + x = mx.unflatten(x, 0, shape) + return x + + parameters = [ + # L, K, D, E, I, transpose + (128, 1024, 1024, 32, 4, True), + (128, 1024, 544, 32, 4, True), + (433, 1024, 1024, 32, 4, True), + (433, 1024, 555, 32, 4, True), + (433, 2048, 1024, 32, 4, True), + (128, 1024, 1024, 32, 4, False), + (128, 1024, 544, 32, 4, False), + (433, 1024, 1024, 32, 4, False), + (433, 1024, 544, 32, 4, False), + (433, 1024, 555, 32, 4, False), + (433, 2048, 1024, 32, 4, False), + ] + for L, K, D, E, I, transpose in parameters: + K, D = (K, D) if transpose else (D, K) + ishape = (L, I) + xshape = (L, 1, 1, K) + wshape = (E, D, K) if transpose else (E, K, D) + + indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32) + x = mx.random.normal(xshape) / K**0.5 + w = mx.random.normal(wshape) / K**0.5 + w, *wq = quantize(w, transpose=transpose) + + y1 = mx.gather_mm(x, w, rhs_indices=indices) + y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices) + xs, idx, inv_order = gather_sort(x, indices) + y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True) + y4 = mx.gather_qmm( + xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True + ) + y3 = scatter_unsort(y3, inv_order, indices.shape) + y4 = scatter_unsort(y4, inv_order, indices.shape) + + self.assertTrue(mx.allclose(y1, y2, atol=1e-5)) + self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) + self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index d1d4f0bd4..1a1ba23b3 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -635,7 +635,7 @@ class TestVmap(mlx_tests.MLXTestCase): def test_leaks(self): if mx.metal.is_available(): - mem_pre = mx.metal.get_active_memory() + mem_pre = mx.get_active_memory() else: mem_pre = 0 @@ -653,7 +653,7 @@ class TestVmap(mlx_tests.MLXTestCase): gc.collect() if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() + mem_post = mx.get_active_memory() else: mem_post = 0 diff --git a/setup.py b/setup.py index f1769b21f..d742e6595 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,16 @@ from setuptools import Command, Extension, find_namespace_packages, setup from setuptools.command.build_ext import build_ext -def get_version(version): +def get_version(): + with open("mlx/version.h", "r") as fid: + for l in fid: + if "#define MLX_VERSION_MAJOR" in l: + major = l.split()[-1] + if "#define MLX_VERSION_MINOR" in l: + minor = l.split()[-1] + if "#define MLX_VERSION_PATCH" in l: + patch = l.split()[-1] + version = f"{major}.{minor}.{patch}" if "PYPI_RELEASE" not in os.environ: today = datetime.date.today() version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}" @@ -50,10 +59,6 @@ class CMakeBuild(build_ext): debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug cfg = "Debug" if debug else 
"Release" - # CMake lets you override the generator - we need to check this. - # Can be set with Conda-Build, for example. - cmake_generator = os.environ.get("CMAKE_GENERATOR", "") - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code # from Python. @@ -172,7 +177,7 @@ if __name__ == "__main__": setup( name="mlx", - version=get_version("0.24.0"), + version=get_version(), author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b8a65d03e..be4479e70 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,3 +1,6 @@ +# Doctest works fine with cmake 3.5 +set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + FetchContent_Declare( doctest GIT_REPOSITORY "https://github.com/onqtam/doctest" diff --git a/tests/creations_tests.cpp b/tests/creations_tests.cpp index 8f94fa3b8..ea43bd0e2 100644 --- a/tests/creations_tests.cpp +++ b/tests/creations_tests.cpp @@ -139,14 +139,6 @@ TEST_CASE("test astype") { y = astype(x, int32); CHECK_EQ(y.dtype(), int32); CHECK_EQ(y.item(), -3); - - y = astype(x, uint32); - CHECK_EQ(y.dtype(), uint32); - - // Use std::copy since the result is platform dependent - uint32_t v; - std::copy(x.data(), x.data() + 1, &v); - CHECK_EQ(y.item(), v); } } diff --git a/tests/export_import_tests.cpp b/tests/export_import_tests.cpp index 83ee1e590..7ad2c640d 100644 --- a/tests/export_import_tests.cpp +++ b/tests/export_import_tests.cpp @@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") { TEST_CASE("test export functions with kwargs") { std::string file_path = get_temp_file("model.mlxfn"); - auto fun = - [](const std::map& kwargs) -> std::vector { + auto fun = [](const Kwargs& kwargs) -> std::vector { return {kwargs.at("x") + kwargs.at("y")}; }; diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp index 1185ea04f..7aabdf36d 100644 --- a/tests/metal_tests.cpp +++ b/tests/metal_tests.cpp @@ -473,24 +473,24 @@ TEST_CASE("test metal validation") { eval(scatter_max(array(1), {}, array(2), std::vector{})); } -TEST_CASE("test metal memory info") { +TEST_CASE("test memory info") { // Test cache limits { - auto old_limit = metal::set_cache_limit(0); + auto old_limit = set_cache_limit(0); { auto a = zeros({4096}); eval(a); } - CHECK_EQ(metal::get_cache_memory(), 0); - CHECK_EQ(metal::set_cache_limit(old_limit), 0); - CHECK_EQ(metal::set_cache_limit(old_limit), old_limit); + CHECK_EQ(get_cache_memory(), 0); + CHECK_EQ(set_cache_limit(old_limit), 0); + CHECK_EQ(set_cache_limit(old_limit), old_limit); } // Test memory limits { - auto old_limit = metal::set_memory_limit(10); - CHECK_EQ(metal::set_memory_limit(old_limit), 10); - CHECK_EQ(metal::set_memory_limit(old_limit), old_limit); + auto old_limit = set_memory_limit(10); + CHECK_EQ(set_memory_limit(old_limit), 10); + CHECK_EQ(set_memory_limit(old_limit), old_limit); } // Query active and peak memory @@ -498,22 +498,22 @@ TEST_CASE("test metal memory info") { auto a = zeros({4096}); eval(a); synchronize(); - auto active_mem = metal::get_active_memory(); + auto active_mem = get_active_memory(); CHECK(active_mem >= 4096 * 4); { auto b = zeros({4096}); eval(b); } synchronize(); - auto new_active_mem = metal::get_active_memory(); + auto new_active_mem = get_active_memory(); CHECK_EQ(new_active_mem, active_mem); - auto peak_mem = metal::get_peak_memory(); + auto peak_mem = get_peak_memory(); CHECK(peak_mem >= 4096 * 8); - 
auto cache_mem = metal::get_cache_memory(); + auto cache_mem = get_cache_memory(); CHECK(cache_mem >= 4096 * 4); } - metal::clear_cache(); - CHECK_EQ(metal::get_cache_memory(), 0); + clear_cache(); + CHECK_EQ(get_cache_memory(), 0); } diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index b8bc8b45e..de0f3352c 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1241,6 +1241,47 @@ TEST_CASE("test arithmetic unary ops") { // bool x = array({false, true}); CHECK(array_equal(sign(x), x).item()); + + // uint64 + array x_uint64( + {uint64_t(0xa11cc311cb6acd70), + uint64_t(0x7a375ac3ebb533f3), + uint64_t(0x734969adf9d7190c), + uint64_t(0xb400515a4f673424)}); + array expected( + {uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array( + {uint64_t(0xa11cc311cb6acd70), + uint64_t(0x7a375ac3ebb533f3), + uint64_t(0x734969adf9d7190c)}); + expected = array( + {uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001), + uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = + array({uint64_t(0xa11cc311cb6acd70), uint64_t(0x7a375ac3ebb533f3)}); + expected = + array({uint64_t(0x0000000000000001), uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array({uint64_t(0xa11cc311cb6acd70)}); + expected = array({uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array({uint64_t(0xffffffffffffffff)}); + expected = array({uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); + + x_uint64 = array({uint64_t(0x0000000000000001)}); + expected = array({uint64_t(0x0000000000000001)}); + CHECK(array_equal(sign(x_uint64), expected).item()); } constexpr float neginf = -std::numeric_limits::infinity(); @@ -3833,3 +3874,41 @@ TEST_CASE("test contiguous") { CHECK(x.flags().col_contiguous); CHECK_EQ(x.strides(), decltype(x.strides()){1, 2}); } + +TEST_CASE("test bitwise shift operations") { + std::vector dtypes = { + int8, int16, int32, int64, uint8, uint16, uint32, uint64}; + + for (const auto& dtype : dtypes) { + array x = full({4}, 1, dtype); + array y = full({4}, 2, dtype); + + auto left_shift_result = left_shift(x, y); + CHECK_EQ(left_shift_result.dtype(), dtype); + CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype)) + .item()); + + auto right_shift_result = right_shift(full({4}, 4, dtype), y); + CHECK_EQ(right_shift_result.dtype(), dtype); + CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item()); + } + + array x = array({127, -128}, int8); + array y = array({1, 1}, int8); + auto left_shift_result = left_shift(x, y); + auto right_shift_result = right_shift(x, y); + + CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item()); + CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item()); + + array x_bool = full({4}, true, bool_); + array y_bool = full({4}, true, bool_); + auto left_shift_bool_result = left_shift(x_bool, y_bool); + auto right_shift_bool_result = right_shift(x_bool, y_bool); + + CHECK_EQ(left_shift_bool_result.dtype(), uint8); + CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item()); + + CHECK_EQ(right_shift_bool_result.dtype(), uint8); + CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item()); +} \ No newline at end of file diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp index 
a17f12e33..88c3e7b37 100644
--- a/tests/utils_tests.cpp
+++ b/tests/utils_tests.cpp
@@ -55,3 +55,13 @@ TEST_CASE("test finfo") {
   CHECK_EQ(finfo(float16).min, -65504);
   CHECK_EQ(finfo(float16).max, 65504);
 }
+
+TEST_CASE("test iinfo") {
+  CHECK_EQ(iinfo(int8).dtype, int8);
+  CHECK_EQ(iinfo(int64).dtype, int64);
+  CHECK_EQ(iinfo(int64).max, std::numeric_limits<int64_t>::max());
+  CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
+  CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
+  CHECK_EQ(iinfo(uint64).min, 0);
+  CHECK_EQ(iinfo(int64).min, std::numeric_limits<int64_t>::min());
+}
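
Note: the hunks above move every test from the mx.metal.* memory helpers to top-level equivalents and add python/tests/test_memory.py to cover them. A minimal usage sketch of that relocated API, based only on the calls the new test exercises (values are illustrative, not part of the patch):

# Sketch of the device-agnostic memory API used by test_memory.py;
# each mx.* helper replaces its former mx.metal.* counterpart.
import mlx.core as mx

old_cache_limit = mx.set_cache_limit(0)  # disable buffer caching, returns old limit
a = mx.zeros((4096,))
mx.eval(a)
mx.synchronize()

active = mx.get_active_memory()  # bytes held by live arrays
peak = mx.get_peak_memory()      # high-water mark since the last reset
cached = mx.get_cache_memory()   # 0 while the cache limit is 0

mx.set_cache_limit(old_cache_limit)  # restore the previous limit
mx.clear_cache()
mx.reset_peak_memory()

mx.set_wired_limit stays gated on Metal availability, as the new test_wired_memory case shows.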
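
The remaining hunks introduce or extend several ops: mx.iinfo, mx.logcumsumexp, the "causal" string mask for mx.fast.scaled_dot_product_attention, and the distributed all_max/all_min reductions. A small hedged sketch of the single-process pieces, mirroring the calls in the tests above (shapes and values are illustrative; the distributed reductions additionally need to run under a distributed launcher):

# Illustrative calls only; mirrors test_iinfo, test_logcumsumexp and the
# "causal" mask cases in test_fast_sdpa.
import mlx.core as mx

print(mx.iinfo(mx.int32).min, mx.iinfo(mx.int32).max)  # integer dtype limits

x = mx.random.normal(shape=(4, 8))
running = mx.logcumsumexp(x, axis=-1)  # cumulative logaddexp along the last axis

q = mx.random.normal(shape=(1, 8, 32, 64))   # 8 query heads
k = mx.random.normal(shape=(1, 1, 32, 64))   # 1 key/value head (GQA)
v = mx.random.normal(shape=(1, 1, 32, 64))
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=64**-0.5, mask="causal"   # no explicit mask matrix needed
)
mx.eval(running, out)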