diff --git a/.circleci/config.yml b/.circleci/config.yml index 6dc7ec4df..3d24cb432 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,15 +7,6 @@ parameters: nightly_build: type: boolean default: false - weekly_build: - type: boolean - default: false - test_release: - type: boolean - default: false - linux_release: - type: boolean - default: false jobs: build_documentation: @@ -38,7 +29,7 @@ jobs: pip install --upgrade pip pip install --upgrade cmake pip install -r docs/requirements.txt - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v + pip install . -v - when: condition: not: << parameters.upload-docs >> @@ -70,9 +61,9 @@ jobs: git push -f origin gh-pages linux_build_and_test: - docker: - - image: cimg/python:3.9 - + machine: + image: ubuntu-2204:current + resource_class: large steps: - checkout - run: @@ -84,37 +75,33 @@ jobs: - run: name: Install dependencies command: | - pip install --upgrade cmake - pip install nanobind==2.4.0 - pip install numpy + export DEBIAN_FRONTEND=noninteractive + export NEEDRESTART_MODE=a sudo apt-get update - sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get upgrade -y + pip install --upgrade cmake + sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev - run: name: Install Python package command: | - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ - python3 setup.py build_ext --inplace - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ - python3 setup.py develop + pip install -e ".[dev]" - run: name: Generate package stubs command: | echo "stubs" pip install typing_extensions - python setup.py generate_stubs + python setup.py generate_stubs - run: name: Run Python tests command: | - python3 -m unittest discover python/tests -v + python -m unittest discover python/tests -v mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py - run: name: Build CPP only command: | - mkdir -p build && cd build + mkdir -p build && cd build cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG make -j `nproc` - run: @@ -154,15 +141,14 @@ jobs: name: Install Python package command: | source env/bin/activate - DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ - CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ + DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ pip install -e . -v - run: name: Generate package stubs command: | source env/bin/activate pip install typing_extensions - python setup.py generate_stubs + python setup.py generate_stubs - run: name: Run Python tests command: | @@ -205,13 +191,34 @@ jobs: name: Run Python tests with JIT command: | source env/bin/activate - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ - CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ + CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ pip install -e . 
-v LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \ METAL_DEBUG_ERROR_MODE=0 \ python -m xmlrunner discover -v python/tests -o test-results/gpu_jit + cuda_build_and_test: + machine: + image: linux-cuda-12:default + resource_class: gpu.nvidia.small.gen2 + steps: + - checkout + - run: + name: Install Python package + command: | + sudo apt-get update + sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + python -m venv env + source env/bin/activate + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ + pip install -e ".[dev]" + - run: + name: Run Python tests + command: | + source env/bin/activate + LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v + LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v + build_release: parameters: python_version: @@ -252,21 +259,28 @@ jobs: command: | source env/bin/activate env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \ - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ pip install . -v - run: name: Generate package stubs command: | source env/bin/activate pip install typing_extensions - python setup.py generate_stubs + python setup.py generate_stubs - run: name: Build Python package command: | source env/bin/activate - << parameters.build_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ - python -m build -w + << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w + - when: + condition: + equal: ["3.9", << parameters.python_version >>] + steps: + - run: + name: Build common package + command: | + source env/bin/activate + python setup.py clean --all + << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w - when: condition: << parameters.build_env >> steps: @@ -283,52 +297,99 @@ jobs: python_version: type: string default: "3.9" - extra_env: + build_env: type: string - default: "DEV_RELEASE=1" - docker: - - image: ubuntu:20.04 + default: "" + machine: + image: ubuntu-2204:current + resource_class: large steps: - checkout - run: name: Build wheel command: | PYTHON=python<< parameters.python_version >> - apt-get update - apt-get upgrade -y - DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata - apt-get install -y apt-utils - apt-get install -y software-properties-common - add-apt-repository -y ppa:deadsnakes/ppa - apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full - apt-get install -y libblas-dev liblapack-dev liblapacke-dev - apt-get install -y build-essential git + export DEBIAN_FRONTEND=noninteractive + export NEEDRESTART_MODE=a + sudo apt-get update + sudo apt-get upgrade -y + TZ=Etc/UTC sudo apt-get -y install tzdata + sudo apt-get install -y apt-utils + sudo apt-get install -y software-properties-common + sudo add-apt-repository -y ppa:deadsnakes/ppa + sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full + sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install -y build-essential git $PYTHON -m venv env source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install nanobind==2.4.0 - pip install --upgrade setuptools - pip install numpy pip install auditwheel pip install patchelf pip install build pip install twine - << parameters.extra_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ - pip install . 
-v + << parameters.build_env >> pip install ".[dev]" -v pip install typing_extensions - python setup.py generate_stubs - << parameters.extra_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ - python -m build --wheel - auditwheel show dist/* - auditwheel repair dist/* --plat manylinux_2_31_x86_64 + python setup.py generate_stubs + MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w + bash python/scripts/repair_linux.sh + - when: + condition: + equal: ["3.9", << parameters.python_version >>] + steps: + - run: + name: Build common package + command: | + source env/bin/activate + python setup.py clean --all + << parameters.build_env >> MLX_BUILD_STAGE=2 \ + python -m build -w + auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64 + - when: + condition: << parameters.build_env >> + steps: + - run: + name: Upload packages + command: | + source env/bin/activate + twine upload wheelhouse/*.whl + - store_artifacts: + path: wheelhouse/ + + build_cuda_release: + parameters: + build_env: + type: string + default: "" + machine: + image: linux-cuda-12:default + resource_class: gpu.nvidia.small.gen2 + steps: + - checkout - run: - name: Upload package + name: Build wheel command: | + sudo apt-get update + sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install zip + python -m venv env source env/bin/activate - twine upload wheelhouse/* + pip install auditwheel + pip install patchelf + pip install build + pip install twine + << parameters.build_env >> MLX_BUILD_STAGE=2 \ + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ + python -m build -w + bash python/scripts/repair_cuda.sh + - when: + condition: << parameters.build_env >> + steps: + - run: + name: Upload package + command: | + source env/bin/activate + twine upload wheelhouse/*.whl - store_artifacts: path: wheelhouse/ @@ -340,22 +401,19 @@ workflows: pattern: "^(?!pull/)[-\\w]+$" value: << pipeline.git.branch >> - not: << pipeline.parameters.nightly_build >> - - not: << pipeline.parameters.weekly_build >> - - not: << pipeline.parameters.test_release >> jobs: - mac_build_and_test: matrix: parameters: macosx_deployment_target: ["13.5", "14.0"] - linux_build_and_test + - cuda_build_and_test - build_documentation build_pypi_release: when: and: - not: << pipeline.parameters.nightly_build >> - - not: << pipeline.parameters.weekly_build >> - - not: << pipeline.parameters.test_release >> jobs: - build_release: filters: @@ -437,6 +495,25 @@ workflows: branches: ignore: /.*/ upload-docs: true + - build_linux_release: + filters: + tags: + only: /^v.*/ + branches: + ignore: /.*/ + matrix: + parameters: + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + build_env: ["PYPI_RELEASE=1"] + - build_cuda_release: + filters: + tags: + only: /^v.*/ + branches: + ignore: /.*/ + matrix: + parameters: + build_env: ["PYPI_RELEASE=1"] prb: when: @@ -455,6 +532,8 @@ workflows: macosx_deployment_target: ["13.5", "14.0"] - linux_build_and_test: requires: [ hold ] + - cuda_build_and_test: + requires: [ hold ] nightly_build: when: and: @@ -513,88 +592,8 @@ workflows: - macosx_deployment_target: "15.0" xcode_version: "15.0.0" python_version: "3.13" - weekly_build: - when: - and: - - equal: [ main, << pipeline.git.branch >> ] - - << pipeline.parameters.weekly_build >> - jobs: - - build_release: - matrix: - parameters: - python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - macosx_deployment_target: ["13.5", "14.0", "15.0"] - build_env: ["DEV_RELEASE=1"] - xcode_version: ["16.2.0", "15.0.0"] - 
exclude: - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.9" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.10" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.11" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.12" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.13" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.9" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.10" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.11" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.12" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.13" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.9" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.10" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.11" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.12" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.13" - build_env: "DEV_RELEASE=1" - linux_test_release: - when: - and: - - equal: [ main, << pipeline.git.branch >> ] - - << pipeline.parameters.linux_release >> - jobs: - build_linux_release: matrix: parameters: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - extra_env: ["PYPI_RELEASE=1"] + - build_cuda_release diff --git a/.gitignore b/.gitignore index e748ee2bf..43629548d 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +uv.lock # vim *.swp diff --git a/CMakeLists.txt b/CMakeLists.txt index e2002fc94..9e67e4bf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) +option(MLX_BUILD_CUDA "Build cuda backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -63,10 +64,8 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") message(WARNING "Building for x86_64 arch is not officially supported.") endif() endif() - else() set(MLX_BUILD_METAL OFF) - message(WARNING "MLX is prioritised for Apple silicon systems using macOS.") endif() # ----------------------------- Lib ----------------------------- @@ -83,6 +82,10 @@ if(MLX_BUILD_METAL) set(QUARTZ_LIB "-framework QuartzCore") endif() +if(MLX_BUILD_CUDA) + enable_language(CUDA) +endif() + if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. Unable to build GPU") set(MLX_BUILD_METAL OFF) @@ -226,6 +229,9 @@ target_include_directories( mlx PUBLIC $ $) +# Do not add mlx_EXPORTS define for shared library. 
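+# (CMake defines <target>_EXPORTS by default when compiling a shared library
+# target; an empty DEFINE_SYMBOL suppresses that definition.)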
+set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "") + FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git diff --git a/MANIFEST.in b/MANIFEST.in index 9faafee45..d0daeb7ae 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,6 @@ include CMakeLists.txt +include mlx.pc.in recursive-include mlx/ * +include cmake/* include python/src/* include python/mlx/py.typed # support type hinting as in PEP-561 diff --git a/benchmarks/cpp/irregular_strides.cpp b/benchmarks/cpp/irregular_strides.cpp index cda76fed6..552461335 100644 --- a/benchmarks/cpp/irregular_strides.cpp +++ b/benchmarks/cpp/irregular_strides.cpp @@ -1,5 +1,6 @@ // Copyright © 2023 Apple Inc. +#include #include #include diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp index 5b327be58..1f93a78d7 100644 --- a/benchmarks/cpp/single_ops.cpp +++ b/benchmarks/cpp/single_ops.cpp @@ -192,6 +192,22 @@ void time_reductions() { auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); }; TIME(argmin_along_1); + + auto indices = mx::array({1}); + auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1}); + std::vector axes{0}; + auto b = scatter(a, {indices}, updates, axes); + mx::eval(b); + + auto max_along_0 = [&b]() { return mx::max(b, 0, false); }; + TIME(max_along_0); + auto max_along_1 = [&b]() { return mx::max(b, 1, false); }; + TIME(max_along_1); + + auto min_along_0 = [&b]() { return mx::min(b, 0, false); }; + TIME(min_along_0); + auto min_along_1 = [&b]() { return mx::min(b, 1, false); }; + TIME(min_along_1); } void time_gather_scatter() { diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index a2157707b..dd3436d9a 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -5,6 +5,7 @@ import os import time import torch +import torch.cuda import torch.mps @@ -44,8 +45,10 @@ def bench(f, *args): def sync_if_needed(x): - if x.device != torch.device("cpu"): + if x.device == torch.device("mps"): torch.mps.synchronize() + elif x.device == torch.device("cuda"): + torch.cuda.synchronize() @torch.no_grad() @@ -99,6 +102,14 @@ def reduction(op, axis, x): sync_if_needed(x) +@torch.no_grad() +def sum_and_add(axis, x, y): + z = x.sum(axis=axis, keepdims=True) + for i in range(50): + z = (z + y).sum(axis=axis, keepdims=True) + sync_if_needed(x) + + @torch.no_grad() def softmax(axis, x): ys = [] @@ -340,7 +351,11 @@ if __name__ == "__main__": args.axis.pop(0) torch.set_num_threads(1) - device = "cpu" if args.cpu else "mps" + device = "mps" + if torch.cuda.is_available(): + device = "cuda" + if args.cpu: + device = "cpu" types = args.dtype if not types: @@ -460,5 +475,8 @@ if __name__ == "__main__": elif args.benchmark == "selu": print(bench(selu, x)) + elif args.benchmark == "sum_and_add": + print(bench(sum_and_add, axis, *xs)) + else: raise ValueError(f"Unknown benchmark `{args.benchmark}`.") diff --git a/benchmarks/python/conv_unaligned_bench.py b/benchmarks/python/conv_unaligned_bench.py new file mode 100644 index 000000000..981d7b48b --- /dev/null +++ b/benchmarks/python/conv_unaligned_bench.py @@ -0,0 +1,107 @@ +import math +import time + +import mlx.core as mx +import numpy as np +import torch + +N_warmup = 10 +N_iter_bench = 100 +N_iter_func = 5 + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + torch.mps.synchronize() + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def 
make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): + def mx_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_2D + + +def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): + @torch.no_grad() + def pt_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + torch.mps.synchronize() + return ys + + return pt_conv_2D + + +def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): + scale = 1.0 / math.sqrt(kH * kH * C) + a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( + np_dtype + ) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps") + b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps") + + torch.mps.synchronize() + + f_mx = make_mx_conv_2D(strides, padding, groups) + f_pt = make_pt_conv_2D(strides, padding, groups) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) + out_pt = torch.conv2d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + dtype = "float32" + shapes = ( + (4, 32, 32, 21, 3, 3, 128), + (4, 32, 32, 21, 3, 3, 37), + (4, 32, 32, 370, 3, 3, 370), + (4, 32, 32, 370, 7, 7, 128), + (2, 320, 640, 21, 7, 7, 21), + ) + for N, H, W, C, kh, kw, O in shapes: + time_mlx, time_torch = bench_shape( + N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/benchmarks/python/layer_norm_bench.py b/benchmarks/python/layer_norm_bench.py index 69263835a..29925de0b 100644 --- a/benchmarks/python/layer_norm_bench.py +++ b/benchmarks/python/layer_norm_bench.py @@ -1,5 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
+from functools import partial + import mlx.core as mx import mlx.nn as nn from time_utils import time_fn @@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps): return y -def time_layer_norm(): +def time_layer_norm(N, dt): + L = 1024 f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum() f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0, 1, 2)) g2 = mx.grad(f2, argnums=(0, 1, 2)) - x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) - w = mx.random.uniform(shape=(4096,)).astype(mx.float16) - b = mx.random.uniform(shape=(4096,)).astype(mx.float16) - y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(8, L, N)).astype(dt) + w = mx.random.uniform(shape=(N,)).astype(dt) + b = mx.random.uniform(shape=(N,)).astype(dt) + y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) - def layer_norm_loop(g, x, w, b): + def layer_norm_loop(f, x, w, b): + for _ in range(32): + x = f(x, w, b) + return x + + time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b) + time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b) + + def layer_norm_grad_loop(g, x, w, b): gx, gw, gb = x, w, b for _ in range(32): gx, gw, gb = g(gx, gw, gb, y) return gx, gw, gb - time_fn(layer_norm_loop, g1, x, w, b) - time_fn(layer_norm_loop, g2, x, w, b) - time_fn(layer_norm_loop, mx.compile(g1), x, w, b) - time_fn(layer_norm_loop, mx.compile(g2), x, w, b) + time_fn(layer_norm_grad_loop, g1, x, w, b) + time_fn(layer_norm_grad_loop, g2, x, w, b) + time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b) + time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b) f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum() f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0,)) g2 = mx.grad(f2, argnums=(0,)) - x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) - w = mx.random.uniform(shape=(4096,)).astype(mx.float16) - b = mx.random.uniform(shape=(4096,)).astype(mx.float16) - y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(8, L, N)).astype(dt) + w = mx.random.uniform(shape=(N,)).astype(dt) + b = mx.random.uniform(shape=(N,)).astype(dt) + y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) - def layer_norm_loop(g, x): + def layer_norm_grad_x_loop(g, x): gx = x for _ in range(32): gx = g(gx, y) return gx - time_fn(layer_norm_loop, g1, x) - time_fn(layer_norm_loop, g2, x) - time_fn(layer_norm_loop, mx.compile(g1), x) - time_fn(layer_norm_loop, mx.compile(g2), x) + time_fn(layer_norm_grad_x_loop, g1, x) + time_fn(layer_norm_grad_x_loop, g2, x) + time_fn(layer_norm_grad_x_loop, mx.compile(g1), x) + time_fn(layer_norm_grad_x_loop, mx.compile(g2), x) if __name__ == "__main__": - time_layer_norm() + for dt in [mx.float32, mx.float16, mx.bfloat16]: + for n in [1024, 2048, 4096, 8192, 8192 + 1024]: + print(dt, n) + time_layer_norm(n, dt) diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py index 3160a1833..939faf305 100644 --- a/benchmarks/python/single_ops.py +++ b/benchmarks/python/single_ops.py @@ -51,6 +51,20 @@ def time_maximum(): time_fn(mx.maximum, a, b) +def time_max(): + a = mx.random.uniform(shape=(32, 1024, 1024)) + a[1, 1] = mx.nan + mx.eval(a) + time_fn(mx.max, a, 0) + + +def time_min(): + a = mx.random.uniform(shape=(32, 1024, 1024)) + a[1, 1] = mx.nan + mx.eval(a) + time_fn(mx.min, a, 0) + + def time_negative(): a = mx.random.uniform(shape=(10000, 
1000)) mx.eval(a) @@ -108,6 +122,8 @@ if __name__ == "__main__": time_add() time_matmul() + time_min() + time_max() time_maximum() time_exp() time_negative() diff --git a/cmake/extension.cmake b/cmake/extension.cmake index 3270b0056..13db804a1 100644 --- a/cmake/extension.cmake +++ b/cmake/extension.cmake @@ -11,13 +11,14 @@ include(CMakeParseArguments) # Args: TARGET: Custom target to be added for the metal library TITLE: Name of # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency -# files (like headers) +# files (like headers) DEBUG: Boolean, if true, enables debug compile options +# for this specific library. If not provided, uses global MLX_METAL_DEBUG. # # clang format on macro(mlx_build_metallib) # Parse args - set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) + set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -26,6 +27,10 @@ macro(mlx_build_metallib) # Collect compile options set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) + if(MLX_METAL_DEBUG OR MTLLIB_DEBUG) + set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only + -frecord-sources) + endif() # Prepare metallib build command add_custom_command( diff --git a/docs/src/conf.py b/docs/src/conf.py index abc68c3a2..d9dd32ad1 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -10,7 +10,7 @@ import mlx.core as mx # -- Project information ----------------------------------------------------- project = "MLX" -copyright = "2023, MLX Contributors" +copyright = "2023, Apple" author = "MLX Contributors" version = ".".join(mx.__version__.split(".")[:3]) release = version diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 3e92f2814..873b1e544 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs. Simple Example -------------- +.. currentmodule:: mlx.core + Let's write a custom kernel that computes ``exp`` elementwise: .. code-block:: python - def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - T tmp = inp[elem]; - out[elem] = metal::exp(tmp); - """ + source = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ - kernel = mx.fast.metal_kernel( - name="myexp", - input_names=["inp"], - output_names=["out"], - source=source, - ) + kernel = mx.fast.metal_kernel( + name="myexp", + input_names=["inp"], + output_names=["out"], + source=source, + ) + + def exp_elementwise(a: mx.array): outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise: b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) +Every time you make a kernel, a new Metal library is created and possibly +JIT compiled. To reduce the overhead from that, build the kernel once with +:func:`fast.metal_kernel` and then use it many times. + .. note:: - We are only required to pass the body of the Metal kernel in ``source``. + Only pass the body of the Metal kernel in ``source``. The function + signature is generated automatically. 
The full function signature will be generated using: @@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float; -Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads `_ function. -This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. -For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension. +Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads +`_ +function. This means we will launch ``mx.prod(grid)`` threads, subdivided into +``threadgroup`` size threadgroups. For optimal performance, each thread group +dimension should be less than or equal to the corresponding grid dimension. -Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. +Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the +generated code for debugging purposes. Using Shape/Strides ------------------- -``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. -This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. -Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims -when indexing. +:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which +is ``True`` by default. This will copy the array inputs if needed +before the kernel is launched to ensure that the memory layout is row +contiguous. Generally this makes writing the kernel easier, since we don't +have to worry about gaps or the ordering of the dims when indexing. -If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each -input array ``a`` if any are present in ``source``. -We can then use MLX's built in indexing utils to fetch the right elements for each thread. +If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes +``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are +present in ``source``. We can then use MLX's built in indexing utils to fetch +the right elements for each thread. -Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: +Let's convert ``myexp`` above to support arbitrarily strided arrays without +relying on a copy from ``ensure_row_contiguous``: .. 
code-block:: python + + source = """ + uint elem = thread_position_in_grid.x; + // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + // Output arrays are always row contiguous + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp_strided", + input_names=["inp"], + output_names=["out"], + source=source + ) def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included - uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); - T tmp = inp[loc]; - // Output arrays are always row contiguous - out[elem] = metal::exp(tmp); - """ - - kernel = mx.fast.metal_kernel( - name="myexp_strided", - input_names=["inp"], - output_names=["out"], - source=source - ) outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops: .. code-block:: python - def grid_sample_ref(x, grid): - N, H_in, W_in, _ = x.shape - ix = ((grid[..., 0] + 1) * W_in - 1) / 2 - iy = ((grid[..., 1] + 1) * H_in - 1) / 2 + def grid_sample_ref(x, grid): + N, H_in, W_in, _ = x.shape + ix = ((grid[..., 0] + 1) * W_in - 1) / 2 + iy = ((grid[..., 1] + 1) * H_in - 1) / 2 - ix_nw = mx.floor(ix).astype(mx.int32) - iy_nw = mx.floor(iy).astype(mx.int32) + ix_nw = mx.floor(ix).astype(mx.int32) + iy_nw = mx.floor(iy).astype(mx.int32) - ix_ne = ix_nw + 1 - iy_ne = iy_nw + ix_ne = ix_nw + 1 + iy_ne = iy_nw - ix_sw = ix_nw - iy_sw = iy_nw + 1 + ix_sw = ix_nw + iy_sw = iy_nw + 1 - ix_se = ix_nw + 1 - iy_se = iy_nw + 1 + ix_se = ix_nw + 1 + iy_se = iy_nw + 1 - nw = (ix_se - ix) * (iy_se - iy) - ne = (ix - ix_sw) * (iy_sw - iy) - sw = (ix_ne - ix) * (iy - iy_ne) - se = (ix - ix_nw) * (iy - iy_nw) + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) - I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] - I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] - I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] - I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] + I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] + I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] + I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] + I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] - mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) - mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) - mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) - mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) + mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) + mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) + mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) + mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) - I_nw *= mask_nw[..., None] - I_ne *= mask_ne[..., None] - I_sw *= mask_sw[..., None] - I_se *= mask_se[..., None] + I_nw *= mask_nw[..., None] + I_ne *= mask_ne[..., None] + I_sw *= mask_sw[..., None] + I_se *= mask_se[..., None] - output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se + output = nw[..., None] * I_nw + ne[..., 
None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se - return output + return output -Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` +Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` to write a fast GPU kernel for both the forward and backward passes. First we'll implement the forward pass as a fused kernel: .. code-block:: python - @mx.custom_function - def grid_sample(x, grid): + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + int gH = grid_shape[1]; + int gW = grid_shape[2]; - assert x.ndim == 4, "`x` must be 4D." - assert grid.ndim == 4, "`grid` must be 4D." + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - B, _, _, C = x.shape - _, gN, gM, D = grid.shape - out_shape = (B, gN, gM, C) + uint grid_idx = elem / C * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - assert D == 2, "Last dim of `grid` must be size 2." + int ix_nw = floor(ix); + int iy_nw = floor(iy); - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - uint grid_idx = elem / C * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int batch_idx = elem / C / gH / gW * b_stride; + int channel_idx = elem % C; + int base_idx = batch_idx + channel_idx; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; + T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; + T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; + T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; + I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; + I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; + I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; + """ - int batch_idx = elem / C / gH / gW * b_stride; - int channel_idx = elem % C; - int base_idx = batch_idx + channel_idx; + kernel = mx.fast.metal_kernel( + name="grid_sample", + input_names=["x", "grid"], + output_names=["out"], + source=source, + ) - T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; - T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; - T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; - T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; + @mx.custom_function + def grid_sample(x, grid): - I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? 
I_nw : 0; - I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; - I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; - I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; + assert x.ndim == 4, "`x` must be 4D." + assert grid.ndim == 4, "`grid` must be 4D." - out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; - """ - kernel = mx.fast.metal_kernel( - name="grid_sample", - input_names=["x", "grid"], - output_names=["out"], - source=source, - ) - outputs = kernel( - inputs=[x, grid], - template=[("T", x.dtype)], - output_shapes=[out_shape], - output_dtypes=[x.dtype], - grid=(np.prod(out_shape), 1, 1), - threadgroup=(256, 1, 1), - ) - return outputs[0] + B, _, _, C = x.shape + _, gN, gM, D = grid.shape + out_shape = (B, gN, gM, C) + + assert D == 2, "Last dim of `grid` must be size 2." + + outputs = kernel( + inputs=[x, grid], + template=[("T", x.dtype)], + output_shapes=[out_shape], + output_dtypes=[x.dtype], + grid=(np.prod(out_shape), 1, 1), + threadgroup=(256, 1, 1), + ) + return outputs[0] For a reasonably sized input such as: .. code-block:: python - x.shape = (8, 1024, 1024, 64) - grid.shape = (8, 256, 256, 2) + x.shape = (8, 1024, 1024, 64) + grid.shape = (8, 256, 256, 2) On an M1 Max, we see a big performance improvement: @@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement: Grid Sample VJP --------------- -Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define -its custom vjp transform so MLX can differentiate it. +Since we decorated ``grid_sample`` with :func:`custom_function`, we can now +define its custom vjp transform so MLX can differentiate it. The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so -requires a few extra ``mx.fast.metal_kernel`` features: +requires a few extra :func:`fast.metal_kernel` features: * ``init_value=0`` Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. @@ -299,128 +316,129 @@ We can then implement the backwards pass as follows: .. code-block:: python - @grid_sample.vjp - def grid_sample_vjp(primals, cotangent, _): - x, grid = primals - B, _, _, C = x.shape - _, gN, gM, D = grid.shape + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + // Pad C to the nearest larger simdgroup size multiple + int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; - assert D == 2, "Last dim of `grid` must be size 2." 
+ int gH = grid_shape[1]; + int gW = grid_shape[2]; - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - // Pad C to the nearest larger simdgroup size multiple - int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + uint grid_idx = elem / C_padded * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_nw = floor(ix); + int iy_nw = floor(iy); - uint grid_idx = elem / C_padded * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + int batch_idx = elem / C_padded / gH / gW * b_stride; + int channel_idx = elem % C_padded; + int base_idx = batch_idx + channel_idx; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + T gix = T(0); + T giy = T(0); + if (channel_idx < C) { + int cot_index = elem / C_padded * C + channel_idx; + T cot = cotangent[cot_index]; + if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { + int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); - int batch_idx = elem / C_padded / gH / gW * b_stride; - int channel_idx = elem % C_padded; - int base_idx = batch_idx + channel_idx; + T I_nw = x[offset]; + gix -= I_nw * (iy_se - iy) * cot; + giy -= I_nw * (ix_se - ix) * cot; + } + if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { + int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); - T gix = T(0); - T giy = T(0); - if (channel_idx < C) { - int cot_index = elem / C_padded * C + channel_idx; - T cot = cotangent[cot_index]; - if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { - int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); + T I_ne = x[offset]; + gix += I_ne * (iy_sw - iy) * cot; + giy -= I_ne * (ix - ix_sw) * cot; + } + if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { + int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); - T I_nw = x[offset]; - gix -= I_nw * (iy_se - iy) * cot; - giy -= I_nw * (ix_se - ix) * cot; - } - if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { - int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); + T I_sw = x[offset]; + gix -= I_sw * (iy - iy_ne) * cot; + giy += I_sw * (ix_ne - ix) * cot; + } + if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= 
W - 1) { + int offset = base_idx + iy_se * h_stride + ix_se * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); - T I_ne = x[offset]; - gix += I_ne * (iy_sw - iy) * cot; - giy -= I_ne * (ix - ix_sw) * cot; - } - if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { - int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); + T I_se = x[offset]; + gix += I_se * (iy - iy_nw) * cot; + giy += I_se * (ix - ix_nw) * cot; + } + } - T I_sw = x[offset]; - gix -= I_sw * (iy - iy_ne) * cot; - giy += I_sw * (ix_ne - ix) * cot; - } - if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { - int offset = base_idx + iy_se * h_stride + ix_se * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); + T gix_mult = W / 2; + T giy_mult = H / 2; - T I_se = x[offset]; - gix += I_se * (iy - iy_nw) * cot; - giy += I_se * (ix - ix_nw) * cot; - } - } + // Reduce across each simdgroup first. + // This is much faster than relying purely on atomics. + gix = simd_sum(gix); + giy = simd_sum(giy); - T gix_mult = W / 2; - T giy_mult = H / 2; + if (thread_index_in_simdgroup == 0) { + atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); + atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); + } + """ + kernel = mx.fast.metal_kernel( + name="grid_sample_grad", + input_names=["x", "grid", "cotangent"], + output_names=["x_grad", "grid_grad"], + source=source, + atomic_outputs=True, + ) - // Reduce across each simdgroup first. - // This is much faster than relying purely on atomics. - gix = simd_sum(gix); - giy = simd_sum(giy); + @grid_sample.vjp + def grid_sample_vjp(primals, cotangent, _): + x, grid = primals + B, _, _, C = x.shape + _, gN, gM, D = grid.shape - if (thread_index_in_simdgroup == 0) { - atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); - atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); - } - """ - kernel = mx.fast.metal_kernel( - name="grid_sample_grad", - input_names=["x", "grid", "cotangent"], - output_names=["x_grad", "grid_grad"], - source=source, - atomic_outputs=True, - ) - # pad the output channels to simd group size - # so that our `simd_sum`s don't overlap. - simdgroup_size = 32 - C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size - grid_size = B * gN * gM * C_padded - outputs = kernel( - inputs=[x, grid, cotangent], - template=[("T", x.dtype)], - output_shapes=[x.shape, grid.shape], - output_dtypes=[x.dtype, x.dtype], - grid=(grid_size, 1, 1), - threadgroup=(256, 1, 1), - init_value=0, - ) - return outputs[0], outputs[1] + assert D == 2, "Last dim of `grid` must be size 2." + + # pad the output channels to simd group size + # so that our `simd_sum`s don't overlap. 
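+        # (Metal simdgroups are 32 threads wide, so C is rounded up to the
+        # next multiple of 32 before sizing the launch grid.)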
+ simdgroup_size = 32 + C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size + grid_size = B * gN * gM * C_padded + outputs = kernel( + inputs=[x, grid, cotangent], + template=[("T", x.dtype)], + output_shapes=[x.shape, grid.shape], + output_dtypes=[x.dtype, x.dtype], + grid=(grid_size, 1, 1), + threadgroup=(256, 1, 1), + init_value=0, + ) + return outputs[0], outputs[1] There's an even larger speed up for the vjp: diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 2aef28f99..5a4de8123 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -138,13 +138,13 @@ more concrete: * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ - virtual std::pair, std::vector> vmap( + std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; - /** Print the primitive. */ - void print(std::ostream& os) override { - os << "Axpby"; + /** The name of primitive. */ + const char* name() const override { + return "Axpby"; } /** Equivalence check **/ @@ -397,11 +397,11 @@ below. std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/docs/src/install.rst b/docs/src/install.rst index 059b2cba4..70491ac64 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -23,13 +23,24 @@ To install from PyPI you must meet the following requirements: MLX is only available on devices running macOS >= 13.5 It is highly recommended to use macOS 14 (Sonoma) +CUDA +^^^^ -MLX is also available on conda-forge. To install MLX with conda do: +MLX has a CUDA backend which you can use on any Linux platform with CUDA 12 +and SM 7.0 (Volta) and up. To install MLX with CUDA support, run: .. code-block:: shell - conda install conda-forge::mlx + pip install "mlx[cuda]" +CPU-only (Linux) +^^^^^^^^^^^^^^^^ + +For a CPU-only version of MLX that runs on Linux use: + +.. code-block:: shell + + pip install "mlx[cpu]" Troubleshooting ^^^^^^^^^^^^^^^ @@ -65,6 +76,8 @@ Build Requirements Python API ^^^^^^^^^^ +.. _python install: + To build and install the MLX python library from source, first, clone MLX from `its GitHub repo `_: @@ -76,20 +89,20 @@ Then simply build and install MLX using pip: .. code-block:: shell - CMAKE_BUILD_PARALLEL_LEVEL=8 pip install . + pip install . For developing, install the package with development dependencies, and use an editable install: .. code-block:: shell - CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]" + pip install -e ".[dev]" Once the development dependencies are installed, you can build faster with: .. code-block:: shell - CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace + python setup.py build_ext --inplace Run the tests with: @@ -107,6 +120,8 @@ IDE: C++ API ^^^^^^^ +.. _cpp install: + Currently, MLX must be built and installed from source. Similarly to the python library, to build and install the MLX C++ library start @@ -185,6 +200,7 @@ should point to the path to the built metal library. 
xcrun -sdk macosx --show-sdk-version + Binary Size Minimization ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -213,6 +229,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the application. Once a kernel is compiled, it will be cached by the system. The Metal kernel cache persists across reboots. +Linux +^^^^^ + +To build from source on Linux (CPU only), install the BLAS and LAPACK headers. +For example on Ubuntu, run the following: + +.. code-block:: shell + + apt-get update -y + apt-get install libblas-dev liblapack-dev liblapacke-dev -y + +From here follow the instructions to install either the :ref:`Python ` or :ref:`C++ ` APIs. + +CUDA +^^^^ + +To build from source on Linux with CUDA, install the BLAS and LAPACK headers +and the CUDA toolkit. For example on Ubuntu, run the following: + +.. code-block:: shell + + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb + dpkg -i cuda-keyring_1.1-1_all.deb + apt-get update -y + apt-get -y install cuda-toolkit-12-9 + apt-get install libblas-dev liblapack-dev liblapacke-dev -y + + +When building either the Python or C++ APIs make sure to pass the cmake flag +``MLX_BUILD_CUDA=ON``. For example, to build the Python API run: + +.. code-block:: shell + + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]" + +To build the C++ package run: + +.. code-block:: shell + + mkdir -p build && cd build + cmake .. -DMLX_BUILD_CUDA=ON && make -j + + Troubleshooting ^^^^^^^^^^^^^^^ diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index 7e1c3339d..e68524d5a 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -19,6 +19,8 @@ Array array.ndim array.shape array.size + array.real + array.imag array.abs array.all array.any diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst index 9e4be084b..36d9d7838 100644 --- a/docs/src/python/fft.rst +++ b/docs/src/python/fft.rst @@ -20,3 +20,5 @@ FFT irfft2 rfftn irfftn + fftshift + ifftshift diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index b01f74117..495380c46 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -16,6 +16,8 @@ Linear Algebra cross qr svd + eigvals + eig eigvalsh eigh lu diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst index c74e357fa..dcbc84c1b 100644 --- a/docs/src/usage/indexing.rst +++ b/docs/src/usage/indexing.rst @@ -107,6 +107,16 @@ same array: >>> a array([1, 2, 0], dtype=int32) + +Note, unlike NumPy, updates to the same location are nondeterministic: + +.. code-block:: shell + + >>> a = mx.array([1, 2, 3]) + >>> a[[0, 0]] = mx.array([4, 5]) + +The first element of ``a`` could be ``4`` or ``5``. + Transformations of functions which use in-place updates are allowed and work as expected. For example: diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 291246617..9ba933483 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -172,11 +172,11 @@ void Axpby::eval_gpu( kname << (contiguous_kernel ? 
"contiguous_" : "general_"); kname << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h index 26f80961c..e6da491f8 100644 --- a/examples/extensions/axpby/axpby.h +++ b/examples/extensions/axpby/axpby.h @@ -74,9 +74,9 @@ class Axpby : public mx::Primitive { const std::vector& inputs, const std::vector& axes) override; - /** Print the primitive. */ - void print(std::ostream& os) override { - os << "Axpby"; + /** The name of primitive. */ + const char* name() const override { + return "Axpby"; } /** Equivalence check **/ diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index abf46a7d5..7aa648533 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -21,7 +21,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) # Define MLX_VERSION only in the version.cpp file. -add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) +add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}") target_link_libraries(mlx PRIVATE $) @@ -49,5 +49,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) +endif() + +if(MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp) +endif() + +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/array.h b/mlx/array.h index 66a4702a6..98eef2e33 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -224,6 +224,10 @@ class array { // Not copyable Data(const Data& d) = delete; Data& operator=(const Data& d) = delete; + Data(Data&& o) : buffer(o.buffer), d(o.d) { + o.buffer = allocator::Buffer(nullptr); + o.d = [](allocator::Buffer) {}; + } ~Data() { d(buffer); } @@ -356,7 +360,7 @@ class array { } enum Status { - // The ouptut of a computation which has not been scheduled. + // The output of a computation which has not been scheduled. // For example, the status of `x` in `auto x = a + b`. unscheduled, diff --git a/mlx/backend/common/buffer_cache.h b/mlx/backend/common/buffer_cache.h new file mode 100644 index 000000000..92b20f222 --- /dev/null +++ b/mlx/backend/common/buffer_cache.h @@ -0,0 +1,157 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +template +class BufferCache { + public: + BufferCache( + size_t page_size, + std::function get_size, + std::function free) + : page_size_(page_size), + get_size_(std::move(get_size)), + free_(std::move(free)) {} + + ~BufferCache() { + clear(); + } + + BufferCache(const BufferCache&) = delete; + BufferCache& operator=(const BufferCache&) = delete; + + T* reuse_from_cache(size_t size) { + // Find the closest buffer in pool. 
+ auto it = buffer_pool_.lower_bound(size); + if (it == buffer_pool_.end() || + it->first >= std::min(2 * size, size + 2 * page_size_)) { + return nullptr; + } + + // Collect from the cache. + T* buf = it->second->buf; + pool_size_ -= it->first; + + // Remove from record. + remove_from_list(it->second); + buffer_pool_.erase(it); + return buf; + } + + void recycle_to_cache(T* buf) { + assert(buf); + // Add to cache. + BufferHolder* bh = new BufferHolder(buf); + add_at_head(bh); + size_t size = get_size_(buf); + pool_size_ += size; + buffer_pool_.emplace(size, bh); + } + + int release_cached_buffers(size_t min_bytes_to_free) { + if (min_bytes_to_free >= 0.9 * pool_size_) { + return clear(); + } else { + int n_release = 0; + size_t total_bytes_freed = 0; + + while (tail_ && (total_bytes_freed < min_bytes_to_free)) { + // Release buffer. + size_t size = get_size_(tail_->buf); + total_bytes_freed += size; + free_(tail_->buf); + n_release++; + + // Remove from record. + auto its = buffer_pool_.equal_range(size); + auto it = std::find_if(its.first, its.second, [this](const auto& el) { + return el.second == tail_; + }); + assert(it != buffer_pool_.end()); + buffer_pool_.erase(it); + remove_from_list(tail_); + } + + pool_size_ -= total_bytes_freed; + return n_release; + } + } + + int clear() { + int n_release = 0; + for (auto& [size, holder] : buffer_pool_) { + free_(holder->buf); + n_release++; + delete holder; + } + buffer_pool_.clear(); + pool_size_ = 0; + head_ = nullptr; + tail_ = nullptr; + return n_release; + } + + size_t cache_size() const { + return pool_size_; + } + + size_t page_size() const { + return page_size_; + } + + private: + struct BufferHolder { + public: + explicit BufferHolder(T* buf_) : buf(buf_) {} + + BufferHolder* prev{nullptr}; + BufferHolder* next{nullptr}; + T* buf; + }; + + void add_at_head(BufferHolder* to_add) { + if (!head_) { + head_ = to_add; + tail_ = to_add; + } else { + head_->prev = to_add; + to_add->next = head_; + head_ = to_add; + } + } + + void remove_from_list(BufferHolder* to_remove) { + if (to_remove->prev && to_remove->next) { // if middle + to_remove->prev->next = to_remove->next; + to_remove->next->prev = to_remove->prev; + } else if (to_remove->prev && to_remove == tail_) { // if tail + tail_ = to_remove->prev; + tail_->next = nullptr; + } else if (to_remove == head_ && to_remove->next) { // if head + head_ = to_remove->next; + head_->prev = nullptr; + } else if (to_remove == head_ && to_remove == tail_) { // if only element + head_ = nullptr; + tail_ = nullptr; + } + + delete to_remove; + } + + std::multimap buffer_pool_; + BufferHolder* head_{nullptr}; + BufferHolder* tail_{nullptr}; + size_t pool_size_{0}; + + const size_t page_size_; + std::function get_size_; + std::function free_; +}; + +} // namespace mlx::core diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index f7b5598ab..44e2a432b 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -1,8 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/common/compiled.h" -#include "mlx/graph_utils.h" -#include "mlx/primitives.h" +#include "mlx/backend/common/utils.h" #include "mlx/utils.h" namespace mlx::core { @@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) { return print_float_constant(os, x); case bfloat16: return print_float_constant(os, x); + case float64: + return print_float_constant(os, x); case complex64: return print_complex_constant(os, x); case int8: @@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) { return "float16_t"; case bfloat16: return "bfloat16_t"; + case float64: + return "double"; case complex64: return "complex64_t"; case bool_: @@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) { } } -std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids) { - NodeNamer namer; - std::ostringstream os; - std::ostringstream constant_hasher; - - // Fill the input names. This is not really necessary, I just like having A, - // B, C, ... as the inputs. - for (auto& x : inputs) { - namer.get_name(x); - } - - // The primitives describing the tape. For unary and binary primitives this - // must be enough to describe the full computation. - for (auto& a : tape) { - // name and type of output - os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); - // computation performed - a.primitive().print(os); - // name of inputs to the function - for (auto& inp : a.inputs()) { - os << namer.get_name(inp); - } - } - os << "_"; - - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - os << "C"; - print_constant(constant_hasher, x); - } else { - os << (is_scalar(x) ? "S" : "V"); - } - } - os << "_"; - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - continue; - } - os << kindof(x.dtype()) << x.itemsize(); - } - os << "_" << std::hash{}(constant_hasher.str()); - - return os.str(); -} - bool compiled_check_contiguity( const std::vector& inputs, const Shape& shape) { @@ -159,8 +113,7 @@ bool compiled_check_contiguity( void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, - const std::vector& inputs_, - const std::unordered_set& constant_ids_, + const std::function& is_constant, bool contiguous) { if (contiguous) { int o = 0; @@ -175,8 +128,7 @@ void compiled_allocate_outputs( // - Donatable // - Not a constant if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && - in.is_donatable() && - constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + in.is_donatable() && is_constant(i)) { outputs[o++].copy_shared_buffer(in); } // Get representative input flags to properly set non-donated outputs @@ -204,7 +156,7 @@ void compiled_allocate_outputs( // - Not a constant if (in.flags().row_contiguous && in.size() == outputs[o].size() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() && - constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + is_constant(i)) { outputs[o].copy_shared_buffer( in, outputs[o].strides(), in.flags(), in.data_size()); o++; @@ -216,4 +168,74 @@ void compiled_allocate_outputs( } } +std::tuple> compiled_collapse_contiguous_dims( + const std::vector& inputs, + const array& out, + const std::function& is_constant) { + const Shape& shape = out.shape(); + bool contiguous = compiled_check_contiguity(inputs, shape); + if (contiguous) { + return {true, shape, {}}; + } + + std::vector strides_vec{out.strides()}; + for (size_t i = 0; i < inputs.size(); ++i) { + 
// Skip constants. + if (is_constant(i)) { + continue; + } + + // Skip scalar inputs. + const auto& x = inputs[i]; + if (is_scalar(x)) { + continue; + } + + // Broadcast the inputs to the output shape. + Strides xstrides; + size_t j = 0; + for (; j < shape.size() - x.ndim(); ++j) { + if (shape[j] == 1) { + xstrides.push_back(out.strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (size_t i = 0; i < x.ndim(); ++i, ++j) { + if (x.shape(i) == 1) { + if (shape[j] == 1) { + xstrides.push_back(out.strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + strides_vec.push_back(std::move(xstrides)); + } + + auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX); + return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))}; +} + +bool compiled_use_large_index( + const std::vector& inputs, + const std::vector& outputs, + bool contiguous) { + if (contiguous) { + size_t max_size = 0; + for (const auto& in : inputs) { + max_size = std::max(max_size, in.data_size()); + } + return max_size > UINT32_MAX; + } else { + size_t max_size = 0; + for (const auto& o : outputs) { + max_size = std::max(max_size, o.size()); + } + return max_size > UINT32_MAX; + } +} + } // namespace mlx::core diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index f4d28d6ab..e92a6d0ad 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -1,9 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #pragma once +#include #include -#include -#include #include "mlx/array.h" #include "mlx/primitives.h" @@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) { return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); } -std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids); - std::string get_type_string(Dtype d); template void print_float_constant(std::ostream& os, const array& x) { auto old_precision = os.precision(); - os << std::setprecision(std::numeric_limits::digits10 + 1) - << x.item() << std::setprecision(old_precision); + if constexpr (std::is_same_v) { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } else { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } + os << x.item() << std::setprecision(old_precision); } template @@ -60,8 +57,19 @@ bool compiled_check_contiguity( void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, - const std::vector& inputs_, - const std::unordered_set& constant_ids_, + const std::function& is_constant, + bool contiguous); + +// Collapse contiguous dims ignoring scalars and constants. +std::tuple> compiled_collapse_contiguous_dims( + const std::vector& inputs, + const array& out, + const std::function& is_constant); + +// Return whether the kernel should use large index. 
+bool compiled_use_large_index( + const std::vector& inputs, + const std::vector& outputs, bool contiguous); } // namespace mlx::core diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index 0c9f28c94..c23d2e79a 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -2,7 +2,7 @@ #pragma once -#include "mlx/array.h" +#include "mlx/backend/common/utils.h" namespace mlx::core { @@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { if (ctype == CopyType::Vector) { // If the input is donateable, we are doing a vector copy and the types // have the same size, then the input buffer can hold the output. - if (in.is_donatable() && in.itemsize() == out.itemsize()) { + if (is_donatable(in, out)) { out.copy_shared_buffer(in); return true; } else { diff --git a/mlx/backend/common/hadamard.h b/mlx/backend/common/hadamard.h index a8fed76b0..ba5c4e41e 100644 --- a/mlx/backend/common/hadamard.h +++ b/mlx/backend/common/hadamard.h @@ -99,7 +99,11 @@ inline std::pair decompose_hadamard(int n) { "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); } } + if (n > (1 << 26)) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where k <= 26"); + } return {n, m}; } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core diff --git a/mlx/backend/common/matmul.h b/mlx/backend/common/matmul.h new file mode 100644 index 000000000..2faf256d1 --- /dev/null +++ b/mlx/backend/common/matmul.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include + +namespace mlx::core { + +inline std::tuple collapse_batches( + const array& a, + const array& b) { + if (a.ndim() == 2) { + return {{1}, {0}, {0}}; + } + + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + + auto [batch_shape, batch_strides] = + collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); + + auto a_batch_strides = batch_strides[0]; + auto b_batch_strides = batch_strides[1]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + a_batch_strides.push_back(0); + b_batch_strides.push_back(0); + } + + return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides); +} + +inline std::tuple +collapse_batches(const array& a, const array& b, const array& c) { + if (a.ndim() == 2) { + return {{1}, {0}, {0}, {0}}; + } + + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; + + auto [batch_shape, batch_strides] = collapse_contiguous_dims( + A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); + + auto A_batch_stride = batch_strides[0]; + auto B_batch_stride = batch_strides[1]; + auto C_batch_stride = batch_strides[2]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + A_batch_stride.push_back(0); + B_batch_stride.push_back(0); + C_batch_stride.push_back(0); + } + + return std::make_tuple( + batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index 5c7f63b75..ceef46400 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -5,11 +5,9 @@ namespace 
mlx::core { std::pair shapes_without_reduction_axes( - const array& x, + Shape shape, + Strides strides, const std::vector& axes) { - auto shape = x.shape(); - auto strides = x.strides(); - for (int i = axes.size() - 1; i >= 0; i--) { int a = axes[i]; shape.erase(shape.begin() + a); @@ -19,6 +17,15 @@ std::pair shapes_without_reduction_axes( return std::make_pair(shape, strides); } +std::pair shapes_without_reduction_axes( + const array& x, + const std::vector& axes) { + auto shape = x.shape(); + auto strides = x.strides(); + return shapes_without_reduction_axes( + std::move(shape), std::move(strides), axes); +} + ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // The data is all there and we are reducing over everything if (x.size() == x.data_size() && axes.size() == x.ndim() && diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index ddb5c3492..8b24f4f53 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes); +std::pair shapes_without_reduction_axes( + Shape shape, + Strides strides, + const std::vector& axes); } // namespace mlx::core diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h new file mode 100644 index 000000000..a27a1f45c --- /dev/null +++ b/mlx/backend/common/unary.h @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +inline void set_unary_output_data(const array& in, array& out) { + if (in.flags().contiguous) { + if (is_donatable(in, out)) { + out.copy_shared_buffer(in); + } else { + out.set_data( + allocator::malloc(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + out.set_data(allocator::malloc(out.nbytes())); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 35bba9c63..ae169e35e 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -1,9 +1,22 @@ // Copyright © 2023-2024 Apple Inc. 
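// A quick worked example of the new Shape/Strides overload of
// shapes_without_reduction_axes added above (the values here are
// illustrative, not code from the patch): the reduced axes are erased from
// both vectors, walking back to front so earlier indices stay valid.
//
//   Shape shape{2, 3, 4};      // row-contiguous array
//   Strides strides{12, 4, 1};
//   auto [s, st] = shapes_without_reduction_axes(shape, strides, {1});
//   // s  == {2, 4}
//   // st == {12, 1}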
+#include + #include "mlx/backend/common/utils.h" namespace mlx::core { +std::filesystem::path current_binary_dir() { + static std::filesystem::path binary_dir = []() { + Dl_info info; + if (!dladdr(reinterpret_cast(¤t_binary_dir), &info)) { + throw std::runtime_error("Unable to get current binary dir."); + } + return std::filesystem::path(info.dli_fname).parent_path(); + }(); + return binary_dir; +} + std::tuple> collapse_contiguous_dims( const Shape& shape, const std::vector& strides, @@ -101,4 +114,118 @@ std::pair collapse_contiguous_dims( return collapse_contiguous_dims(a.shape(), a.strides(), size_cap); } +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { + int pows[3] = {0, 0, 0}; + int sum = 0; + while (true) { + int presum = sum; + // Check all the pows + if (dim0 >= (1 << (pows[0] + 1))) { + pows[0]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim1 >= (1 << (pows[1] + 1))) { + pows[1]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim2 >= (1 << (pows[2] + 1))) { + pows[2]++; + sum++; + } + if (sum == presum || sum == pow2) { + break; + } + } + return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]); +} + +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) { + // Dims with strides of 0 are ignored as they + // correspond to broadcasted dimensions + size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + return std::make_tuple( + static_cast(grid_x), static_cast(grid_y), 1); +} + +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor) { + // Compute the 2d grid dimensions such that the total size of the grid is + // divided by divisor. + size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + + // No need to add this shape we can just remove it from the divisor. 
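      // Worked example: for a contiguous array of shape {4, 6, 5} with
      // divisor 4, the first dimension divides the divisor exactly, so it is
      // skipped and the divisor drops to 1; the remaining grid is
      // 6 * 5 = 30 = (4 * 6 * 5) / 4, i.e. the Prod(shape) / divisor
      // factorization documented in utils.h.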
+ if (divisor % shape[i] == 0) { + divisor /= shape[i]; + continue; + } + + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + + if (divisor > 1) { + if (grid_x % divisor == 0) { + grid_x /= divisor; + divisor = 1; + } else if (grid_y % divisor == 0) { + grid_y /= divisor; + divisor = 1; + } + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + if (divisor > 1) { + grid_x = ((grid_x + divisor - 1) / divisor) * divisor; + } + return std::make_tuple( + static_cast(grid_x), static_cast(grid_y), 1); +} + +std::pair get_grid_and_block_common(int dim0, int dim1, int dim2) { + auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2); + auto gx = (dim0 + bx - 1) / bx; + auto gy = (dim1 + by - 1) / by; + auto gz = (dim2 + bz - 1) / bz; + + return std::make_pair( + std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); +} + } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 20a65d7b1..0f9846086 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -2,12 +2,17 @@ #pragma once +#include +#include #include #include "mlx/array.h" namespace mlx::core { +// Return the directory that contains current shared library. +std::filesystem::path current_binary_dir(); + inline int64_t elem_to_loc(int elem, const Shape& shape, const Strides& strides) { int64_t loc = 0; @@ -70,6 +75,31 @@ std::pair collapse_contiguous_dims( const array& a, int64_t size_cap = std::numeric_limits::max()); +// Compute the thread block dimensions which fit the given +// input dimensions. +// - The thread block dimensions will be powers of two +// - The thread block size will be less than 2^pow2 +using Dims = std::tuple; +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10); + +// Computes a 2D grid where each element is < UINT_MAX +// Assumes: +// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 +// - shape and strides correspond to a contiguous (no holes) but +// possibly broadcasted array +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides); + +// Same as above but we do an implicit division with divisor. +// Basically, equivalent to factorizing +// Prod(s \forall s in shape if strides[s] > 0) / divisor. +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor); + +// Get both the block and a grid of blocks that covers dim0, dim1 and dim2. 
+std::pair get_grid_and_block_common(int dim0, int dim1, int dim2); + struct ContiguousIterator { inline void step() { int dims = shape_.size(); @@ -165,4 +195,11 @@ void shared_buffer_reshape( const array& in, const Strides& out_strides, array& out); + +template +inline std::vector remove_index(std::vector vec, size_t index) { + vec.erase(std::next(vec.begin(), index)); + return vec; +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 152f33b17..9d322c4c4 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -40,11 +40,13 @@ add_dependencies(mlx cpu_compiled_preamble) target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index a8ba3efe2..66468912d 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -14,10 +14,8 @@ template void arg_reduce(const array& in, array& out, const OpT& op, int axis) { auto axis_size = in.shape()[axis]; auto axis_stride = in.strides()[axis]; - Strides strides = in.strides(); - Shape shape = in.shape(); - strides.erase(strides.begin() + axis); - shape.erase(shape.begin() + axis); + Strides strides = remove_index(in.strides(), axis); + Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); auto out_ptr = out.data(); diff --git a/mlx/backend/cpu/available.cpp b/mlx/backend/cpu/available.cpp new file mode 100644 index 000000000..0449d49b9 --- /dev/null +++ b/mlx/backend/cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return true; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/available.h b/mlx/backend/cpu/available.h new file mode 100644 index 000000000..1df95def2 --- /dev/null +++ b/mlx/backend/cpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cpu { + +bool is_available(); + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index dbdab6a06..35aa2a3e0 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -172,9 +172,12 @@ void binary_float( case bfloat16: binary_op(a, b, out, bopt); break; + case complex64: + binary_op(a, b, out, bopt); + break; default: throw std::runtime_error( - "[binary_float] Only supports non-complex floating point types."); + "[binary_float] Only supports floating point types."); } }); } diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 9da9c14e8..d85114987 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -40,7 +40,10 @@ struct CompilerCache { std::shared_mutex mtx; }; -static CompilerCache cache{}; +static CompilerCache& cache() { + static CompilerCache cache_; + return cache_; +}; // GPU compile is always available if the GPU is available and since we are in // this file CPU compile is also available. 
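// The cache() accessor above is the construct-on-first-use idiom: a
// function-local static is built the first time it is needed (thread-safely in
// C++11 and later) and lives until process exit, so the kernel table and the
// shared libraries it keeps open are valid whenever compile() runs, regardless
// of static-initialization order across translation units. The same idiom in
// isolation (Registry is just a placeholder type for illustration):
//
//   Registry& registry() {
//     static Registry r; // constructed on first call, destroyed at exit
//     return r;
//   }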
@@ -56,14 +59,16 @@ void* compile( const std::string& kernel_name, const std::function& source_builder) { { - std::shared_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::shared_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } } - std::unique_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::unique_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } std::string source_code = source_builder(); @@ -120,10 +125,10 @@ void* compile( } // load library - cache.libs.emplace_back(shared_lib_path); + cache().libs.emplace_back(shared_lib_path); // Load function - void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str()); + void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str()); if (!fun) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to load compiled function " @@ -131,7 +136,7 @@ void* compile( << dlerror(); throw std::runtime_error(msg.str()); } - cache.kernels.insert({kernel_name, fun}); + cache().kernels.insert({kernel_name, fun}); return fun; } @@ -141,18 +146,9 @@ inline void build_kernel( const std::vector& inputs, const std::vector& outputs, const std::vector& tape, - const std::unordered_set& constant_ids, + const std::function& is_constant, bool contiguous, int ndim) { - // All outputs should have the exact same shape and will be row contiguous - auto output_shape = outputs[0].shape(); - auto output_strides = outputs[0].strides(); - - // Constants are scalars that are captured by value and cannot change - auto is_constant = [&constant_ids](const array& x) { - return constant_ids.find(x.id()) != constant_ids.end(); - }; - NodeNamer namer; #ifdef _MSC_VER @@ -165,14 +161,15 @@ inline void build_kernel( // Add the input arguments int cnt = 0; - for (auto& x : inputs) { - auto& xname = namer.get_name(x); - + for (size_t i = 0; i < inputs.size(); ++i) { // Skip constants from the input list - if (is_constant(x)) { + if (is_constant(i)) { continue; } + const auto& x = inputs[i]; + auto& xname = namer.get_name(x); + auto tstr = get_type_string(x.dtype()); os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ << "];" << std::endl; @@ -206,10 +203,11 @@ inline void build_kernel( } // Read the inputs in tmps - for (auto& x : inputs) { + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; auto& xname = namer.get_name(x); - if (is_constant(x)) { + if (is_constant(i)) { os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; print_constant(os, x); os << ";" << std::endl; @@ -233,7 +231,7 @@ inline void build_kernel( os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" << namer.get_name(x.inputs()[0]) << ");" << std::endl; } else { - x.primitive().print(os); + os << x.primitive().name(); os << "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; @@ -259,8 +257,9 @@ inline void build_kernel( } else { for (int d = ndim - 1; d >= 0; --d) { // Update pointers - for (auto& x : inputs) { - if (is_constant(x) || is_scalar(x)) { + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + if (is_constant(i) || is_scalar(x)) { continue; } auto& xname = namer.get_name(x); @@ -282,65 +281,37 @@ inline void build_kernel( void Compiled::eval_cpu( const std::vector& inputs, 
std::vector& outputs) { - if (kernel_lib_.empty()) { - kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); - } - - // Figure out which kernel we are using - auto& shape = outputs[0].shape(); - auto contiguous = compiled_check_contiguity(inputs, shape); auto& encoder = cpu::get_command_encoder(stream()); - // Handle all broadcasting and collect function input arguments + // Collapse contiguous dims to route to a faster kernel if possible. Also + // handle all broadcasting. + auto [contiguous, shape, strides] = + compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); + + // Collect function input arguments. std::vector args; - std::vector> strides; - for (int i = 0; i < inputs.size(); i++) { - // Skip constants. - if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + int strides_index = 1; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant_(i)) { continue; } - auto& x = inputs[i]; + const auto& x = inputs[i]; encoder.set_input_array(x); args.push_back((void*)x.data()); - - if (contiguous || is_scalar(x)) { - continue; + if (!contiguous && !is_scalar(x)) { + args.push_back(strides[strides_index++].data()); } - - // Broadcast the input to the output shape. - std::vector xstrides; - int j = 0; - for (; j < shape.size() - x.ndim(); j++) { - if (shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } - for (int i = 0; i < x.ndim(); i++, j++) { - if (x.shape(i) == 1) { - if (shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } else { - xstrides.push_back(x.strides()[i]); - } - } - strides.push_back(std::move(xstrides)); - args.push_back(strides.back().data()); } // Get the kernel name from the lib int ndim = shape.size(); auto kernel_name = kernel_lib_ + (contiguous ? 
"_contiguous" : "_strided_"); if (!contiguous) { - kernel_name += std::to_string(shape.size()); + kernel_name += std::to_string(ndim); } // Get the function - auto fn_ptr = compile(kernel_name, [&]() { + auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() { std::ostringstream kernel; kernel << get_kernel_preamble() << std::endl; kernel << "extern \"C\" {" << std::endl; @@ -350,7 +321,7 @@ void Compiled::eval_cpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, contiguous, ndim); // Close extern "C" @@ -358,26 +329,22 @@ void Compiled::eval_cpu( return kernel.str(); }); - compiled_allocate_outputs( - inputs, outputs, inputs_, constant_ids_, contiguous); + compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); for (auto& x : outputs) { args.push_back(x.data()); encoder.set_output_array(x); } - Shape out_shape; if (!contiguous) { - out_shape = outputs[0].shape(); - args.push_back((void*)out_shape.data()); + args.push_back((void*)shape.data()); } else { args.push_back((void*)outputs[0].data_size()); } auto fun = (void (*)(void**))fn_ptr; - encoder.dispatch( - [fun, - args = std::move(args), - strides = std::move(strides), - out_shape = std::move(out_shape)]() mutable { fun(args.data()); }); + encoder.dispatch([fun, + args = std::move(args), + strides = std::move(strides), + shape = std::move(shape)]() mutable { fun(args.data()); }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index d52f92f8b..e5636b3b8 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -22,7 +22,8 @@ void slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -60,7 +61,8 @@ void slow_conv_1D( out_stride_O = out.strides()[2], flip, - padding = padding[0], + padding_lo = padding_lo[0], + padding_hi = padding_hi[0], wt_stride = wt_strides[0], wt_dilation = wt_dilation[0], in_dilation = in_dilation[0]]() mutable { @@ -77,7 +79,7 @@ void slow_conv_1D( const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; int wh_flip = flip ? 
(wH - wh - 1) : wh; - int ih = oh * wt_stride - padding + wh_flip * wt_dilation; + int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation; auto ih_div = std::div(ih, in_dilation); @@ -109,7 +111,8 @@ void slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -120,230 +123,235 @@ void slow_conv_2D( encoder.set_input_array(wt); encoder.set_output_array(out); - encoder.dispatch([st_wt_ptr = wt.data(), - st_in_ptr = in.data(), - st_out_ptr = out.data(), + encoder.dispatch( + [st_wt_ptr = wt.data(), + st_in_ptr = in.data(), + st_out_ptr = out.data(), - N = in.shape( - 0), // Batch size, should be the same as out.shape(0) - iH = 1 + - in_dilation[0] * (in.shape(1) - 1), // Input spatial dim - iW = 1 + - in_dilation[1] * (in.shape(2) - 1), // Input spatial dim - C = in.shape(3), // In channels - oH = out.shape(1), // Output spatial dim - oW = out.shape(2), // Output spatial dim - O = wt.shape(0), // Out channels - wH = wt.shape(1), // Weight spatial dim - wW = wt.shape(2), // Weight spatial dim + N = in.shape(0), // Batch size, should be the same as out.shape(0) + iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim + C = in.shape(3), // In channels + oH = out.shape(1), // Output spatial dim + oW = out.shape(2), // Output spatial dim + O = wt.shape(0), // Out channels + wH = wt.shape(1), // Weight spatial dim + wW = wt.shape(2), // Weight spatial dim - groups = in.shape(3) / wt.shape(3), - C_per_group = wt.shape(3), + groups = in.shape(3) / wt.shape(3), + C_per_group = wt.shape(3), - in_stride_N = in.strides()[0], - in_stride_H = in.strides()[1], - in_stride_W = in.strides()[2], - in_stride_C = in.strides()[3], + in_stride_N = in.strides()[0], + in_stride_H = in.strides()[1], + in_stride_W = in.strides()[2], + in_stride_C = in.strides()[3], - wt_stride_O = wt.strides()[0], - wt_stride_H = wt.strides()[1], - wt_stride_W = wt.strides()[2], - wt_stride_C = wt.strides()[3], + wt_stride_O = wt.strides()[0], + wt_stride_H = wt.strides()[1], + wt_stride_W = wt.strides()[2], + wt_stride_C = wt.strides()[3], - out_stride_N = out.strides()[0], - out_stride_H = out.strides()[1], - out_stride_W = out.strides()[2], - out_stride_O = out.strides()[3], + out_stride_N = out.strides()[0], + out_stride_H = out.strides()[1], + out_stride_W = out.strides()[2], + out_stride_O = out.strides()[3], - padding, - wt_strides, - wt_dilation, - in_dilation, - flip]() mutable { - bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip]() mutable { + bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; - const int O_per_group = O / groups; - auto pt_conv_no_checks = [&](const T* in_ptr, - const T* wt_ptr, - T* out_ptr, - int oh, - int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + const int O_per_group = O / groups; + auto pt_conv_no_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) 
* O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + ih * in_stride_H + iw * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c - } // ww - } // wh + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; - int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; + int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; - int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); - int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); + int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); + int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0); - int f_wgt_jump_h = - std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; - int f_wgt_jump_w = - std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int f_wgt_jump_h = + std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_w = + std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; - int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; - int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_out_jump_h = + std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_w = + std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; - std::vector base_h(f_out_jump_h); - std::vector base_w(f_out_jump_w); + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); - for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[0] - padding[0] + init_h; + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h; - int wh_base = 0; - while (wh_base < wH && ih_loop % in_dilation[0] != 0) { - wh_base++; - ih_loop += jump_h; - } + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[0] != 0) { + wh_base++; + ih_loop += jump_h; + } - base_h[i] = wh_base; - } + base_h[i] = wh_base; + } - for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[1] - padding[1] + init_w; + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w; - int ww_base = 0; - while (ww_base < wW && iw_loop % in_dilation[1] != 0) { - ww_base++; - iw_loop += jump_w; - } + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[1] != 0) { + ww_base++; + iw_loop += jump_w; + } - base_w[j] = ww_base; - } + base_w[j] = ww_base; + } - auto pt_conv_all_checks = - [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; + auto pt_conv_all_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - int wh_base = base_h[oh % f_out_jump_h]; - int ww_base = base_w[ow % f_out_jump_w]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - int ih_dil = !is_idil_one ? 
(ih / in_dilation[0]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; - const T* in_ptr_pt = - in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; + const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H + + iw_dil * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; - ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c - } // ih, iw check - } // ww - } // wh + } // ih, iw check + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; - int oH_border_2 = std::max( - oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); - int oH_border_3 = oH; + int oH_border_0 = 0; + int oH_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oH; + int oH_border_2 = std::max( + oH_border_1, + (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]); + int oH_border_3 = oH; - int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; - int oW_border_2 = std::max( - oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); - int oW_border_3 = oW; + int oW_border_0 = 0; + int oW_border_1 = is_idil_one + ? 
((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oW; + int oW_border_2 = std::max( + oW_border_1, + (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]); + int oW_border_3 = oW; - for (int n = 0; n < N; ++n) { - // Case 1: oh might put us out of bounds - for (int oh = oH_border_0; oh < oH_border_1; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + for (int n = 0; n < N; ++n) { + // Case 1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - // Case 2: oh in bounds - for (int oh = oH_border_1; oh < oH_border_2; ++oh) { - // Case a: ow might put us out of bounds - for (int ow = oW_border_0; ow < oW_border_1; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case 2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case a: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case b: ow in bounds - for (int ow = oW_border_1; ow < oW_border_2; ++ow) { - pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case b: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case c: ow might put us out of bounds - for (int ow = oW_border_2; ow < oW_border_3; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case c: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - } // oh + } // oh - // Case 3: oh might put us out of bounds - for (int oh = oH_border_2; oh < oH_border_3; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + // Case 3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - st_in_ptr += in_stride_N; - st_out_ptr += out_stride_N; + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; - } // n - }); + } // n + }); } template @@ -351,7 +359,8 @@ void slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -400,7 +409,8 @@ void slow_conv_3D( out_stride_H = out.strides()[2], out_stride_W = out.strides()[3], out_stride_O = out.strides()[4], - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -415,9 +425,9 @@ void slow_conv_3D( int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; for (int o = 0; o < O; ++o) { float r = 0.; @@ -478,7 +488,7 @@ void slow_conv_3D( std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_d; ++i) { - int 
id_loop = i * wt_strides[0] - padding[0] + init_d; + int id_loop = i * wt_strides[0] - padding_lo[0] + init_d; int wd_base = 0; while (wd_base < wD && id_loop % in_dilation[0] != 0) { @@ -490,7 +500,7 @@ void slow_conv_3D( } for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[1] - padding[1] + init_h; + int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[1] != 0) { @@ -502,7 +512,7 @@ void slow_conv_3D( } for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[2] - padding[2] + init_w; + int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[2] != 0) { @@ -521,9 +531,9 @@ void slow_conv_3D( int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; int wd_base = base_d[od % f_out_jump_d]; int wh_base = base_h[oh % f_out_jump_h]; @@ -573,24 +583,30 @@ void slow_conv_3D( }; int oD_border_0 = 0; - int oD_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; + int oD_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oD; int oD_border_2 = std::max( - oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); + oD_border_1, + (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]); int oD_border_3 = oD; int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; + int oH_border_1 = is_idil_one + ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oH; int oH_border_2 = std::max( - oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); + oH_border_1, + (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]); int oH_border_3 = oH; int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; + int oW_border_1 = is_idil_one + ? 
((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2]) + : oW; int oW_border_2 = std::max( - oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); + oW_border_1, + (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { @@ -658,7 +674,8 @@ void dispatch_slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -669,7 +686,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -680,7 +698,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -691,7 +710,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -707,7 +727,8 @@ void dispatch_slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -718,7 +739,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -729,7 +751,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -740,7 +763,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -756,7 +780,8 @@ void dispatch_slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -767,7 +792,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -778,7 +804,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -789,7 +816,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], C}; + Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = padding[0] * in_padded.strides()[1]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const 
std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; + Shape padded_shape = { + N, + iH + padding_lo[0] + padding_hi[0], + iW + padding_lo[1] + padding_hi[1], + C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = - padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1] + + padding_lo[1] * in_padded.strides()[2]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const bool flip, @@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu( Shape padded_shape(in.shape().size()); padded_shape.front() = N; for (size_t i = 0; i < iDim.size(); i++) { - padded_shape[i + 1] = iDim[i] + 2 * padding[i]; + padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i]; } padded_shape.back() = C; array in_padded(padded_shape, conv_dtype, nullptr, {}); @@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu( // Pick input slice from padded size_t data_offset = 0; - for (size_t i = 0; i < padding.size(); i++) { - data_offset += padding[i] * in_padded.strides()[i + 1]; + for (size_t i = 0; i < padding_lo.size(); i++) { + data_offset += padding_lo[i] * in_padded.strides()[i + 1]; } + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1261,7 +1297,8 @@ void conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1270,22 +1307,40 @@ void conv_1D_cpu( const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( - in, wt, out, padding, wt_strides, wt_dilation, stream); + in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1295,18 +1350,35 @@ void conv_2D_cpu( if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, 
wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } - return dispatch_slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_3D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1317,11 +1389,28 @@ void conv_3D_cpu( in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } } // namespace @@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp new file mode 100644 index 000000000..c89003fc0 --- /dev/null +++ b/mlx/backend/cpu/eig.cpp @@ -0,0 +1,174 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/lapack.h" +#include "mlx/linalg.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void eig_impl( + array& a, + array& vectors, + array& values, + bool compute_eigenvectors, + Stream stream) { + using OT = std::complex; + auto a_ptr = a.data(); + auto eig_ptr = values.data(); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(values); + OT* vec_ptr = nullptr; + if (compute_eigenvectors) { + encoder.set_output_array(vectors); + vec_ptr = vectors.data(); + } + encoder.dispatch([a_ptr, + vec_ptr, + eig_ptr, + compute_eigenvectors, + N = vectors.shape(-1), + size = vectors.size()]() mutable { + // Work query + char jobr = 'N'; + char jobl = compute_eigenvectors ? 'V' : 'N'; + int n_vecs_r = 1; + int n_vecs_l = compute_eigenvectors ? N : 1; + int lwork = -1; + int info; + { + T work; + int iwork; + geev( + &jobl, + &jobr, + &N, + nullptr, + &N, + nullptr, + nullptr, + nullptr, + &n_vecs_l, + nullptr, + &n_vecs_r, + &work, + &lwork, + &info); + lwork = static_cast(work); + } + + auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)}; + auto vec_tmp_data = + array::Data{allocator::malloc(vec_ptr ? 
sizeof(T) * N * N * 2 : 0)}; + auto eig_tmp = static_cast(eig_tmp_data.buffer.raw_ptr()); + auto vec_tmp = static_cast(vec_tmp_data.buffer.raw_ptr()); + auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; + for (size_t i = 0; i < size / (N * N); ++i) { + geev( + &jobl, + &jobr, + &N, + a_ptr, + &N, + eig_tmp, + eig_tmp + N, + vec_tmp, + &n_vecs_l, + nullptr, + &n_vecs_r, + static_cast(work_buf.buffer.raw_ptr()), + &lwork, + &info); + for (int i = 0; i < N; ++i) { + eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]}; + } + if (vec_ptr) { + for (int i = 0; i < N; ++i) { + if (eig_ptr[i].imag() != 0) { + // This vector and the next are a pair + for (int j = 0; j < N; ++j) { + vec_ptr[i * N + j] = { + vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]}; + vec_ptr[(i + 1) * N + j] = { + vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]}; + } + i += 1; + } else { + for (int j = 0; j < N; ++j) { + vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0}; + } + } + } + vec_ptr += N * N; + } + a_ptr += N * N; + eig_ptr += N; + if (info != 0) { + std::stringstream msg; + msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code " + << info; + throw std::runtime_error(msg.str()); + } + } + }); + encoder.add_temporary(a); +} + +} // namespace + +void Eig::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + const auto& a = inputs[0]; + auto& values = outputs[0]; + + auto vectors = compute_eigenvectors_ + ? outputs[1] + : array(a.shape(), complex64, nullptr, {}); + + auto a_copy = array(a.shape(), a.dtype(), nullptr, {}); + copy( + a, + a_copy, + a.flags().row_contiguous ? CopyType::Vector : CopyType::General, + stream()); + + values.set_data(allocator::malloc(values.nbytes())); + + if (compute_eigenvectors_) { + // Set the strides and flags so the eigenvectors + // are in the columns of the output + auto flags = vectors.flags(); + auto strides = vectors.strides(); + auto ndim = a.ndim(); + std::swap(strides[ndim - 1], strides[ndim - 2]); + + if (a.size() > 1) { + flags.row_contiguous = false; + if (ndim > 2) { + flags.col_contiguous = false; + } else { + flags.col_contiguous = true; + } + } + vectors.set_data( + allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags); + } + switch (a.dtype()) { + case float32: + eig_impl(a_copy, vectors, values, compute_eigenvectors_, stream()); + break; + default: + throw std::runtime_error("[Eig::eval_cpu] only supports float32."); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp index b50f2c722..58d3634e8 100644 --- a/mlx/backend/cpu/eigh.cpp +++ b/mlx/backend/cpu/eigh.cpp @@ -12,6 +12,133 @@ namespace mlx::core { namespace { +template +struct EighWork {}; + +template +struct EighWork< + T, + typename std::enable_if::value>::type> { + using R = T; + + char jobz; + char uplo; + int N; + int lwork; + int liwork; + int info; + std::vector buffers; + + EighWork(char jobz_, char uplo_, int N_) + : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) { + T work; + int iwork; + syevd( + &jobz, + &uplo, + &N, + nullptr, + &N, + nullptr, + &work, + &lwork, + &iwork, + &liwork, + &info); + lwork = static_cast(work); + liwork = iwork; + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + buffers.emplace_back(allocator::malloc(sizeof(int) * liwork)); + } + + void run(T* vectors, T* values) { + syevd( + &jobz, + &uplo, + &N, + vectors, + &N, + values, + static_cast(buffers[0].buffer.raw_ptr()), + &lwork, + static_cast(buffers[1].buffer.raw_ptr()), + &liwork, + &info); + } +}; + +template <> 
+struct EighWork> { + using T = std::complex; + using R = float; + + char jobz; + char uplo; + int N; + int lwork; + int lrwork; + int liwork; + int info; + std::vector buffers; + + EighWork(char jobz_, char uplo_, int N_) + : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) { + T work; + R rwork; + int iwork; + heevd( + &jobz, + &uplo, + &N, + nullptr, + &N, + nullptr, + &work, + &lwork, + &rwork, + &lrwork, + &iwork, + &liwork, + &info); + lwork = static_cast(work.real()); + lrwork = static_cast(rwork); + liwork = iwork; + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork)); + buffers.emplace_back(allocator::malloc(sizeof(int) * liwork)); + } + + void run(T* vectors, R* values) { + heevd( + &jobz, + &uplo, + &N, + vectors, + &N, + values, + static_cast(buffers[0].buffer.raw_ptr()), + &lwork, + static_cast(buffers[1].buffer.raw_ptr()), + &lrwork, + static_cast(buffers[2].buffer.raw_ptr()), + &liwork, + &info); + if (jobz == 'V') { + // We have pre-transposed the vectors but we also must conjugate them + // when they are complex. + // + // We could vectorize this but it is so fast in comparison to heevd that + // it doesn't really matter. + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + *vectors = std::conj(*vectors); + vectors++; + } + } + } + } +}; + template void eigh_impl( array& vectors, @@ -19,8 +146,10 @@ void eigh_impl( const std::string& uplo, bool compute_eigenvectors, Stream stream) { + using R = typename EighWork::R; + auto vec_ptr = vectors.data(); - auto eig_ptr = values.data(); + auto eig_ptr = values.data(); char jobz = compute_eigenvectors ? 'V' : 'N'; auto& encoder = cpu::get_command_encoder(stream); @@ -33,49 +162,17 @@ void eigh_impl( N = vectors.shape(-1), size = vectors.size()]() mutable { // Work query - int lwork = -1; - int liwork = -1; - int info; - { - T work; - int iwork; - syevd( - &jobz, - &uplo, - &N, - nullptr, - &N, - nullptr, - &work, - &lwork, - &iwork, - &liwork, - &info); - lwork = static_cast(work); - liwork = iwork; - } + EighWork work(jobz, uplo, N); - auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; - auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)}; + // Work loop for (size_t i = 0; i < size / (N * N); ++i) { - syevd( - &jobz, - &uplo, - &N, - vec_ptr, - &N, - eig_ptr, - static_cast(work_buf.buffer.raw_ptr()), - &lwork, - static_cast(iwork_buf.buffer.raw_ptr()), - &liwork, - &info); + work.run(vec_ptr, eig_ptr); vec_ptr += N * N; eig_ptr += N; - if (info != 0) { + if (work.info != 0) { std::stringstream msg; msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code " - << info; + << work.info; throw std::runtime_error(msg.str()); } } @@ -131,6 +228,10 @@ void Eigh::eval_cpu( eigh_impl( vectors, values, uplo_, compute_eigenvectors_, stream()); break; + case complex64: + eigh_impl>( + vectors, values, uplo_, compute_eigenvectors_, stream()); + break; default: throw std::runtime_error( "[Eigh::eval_cpu] only supports float32 or float64."); diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 70d6b3eb7..5f99093e5 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -257,15 +257,11 @@ void gather_axis( const array& ind, array& out, const int axis) { - auto strides = ind.strides(); - strides.erase(strides.begin() + axis); - auto shape = ind.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator ind_it(shape, strides, src.ndim() - 1); 
- - strides = src.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator src_it(shape, strides, src.ndim() - 1); + auto shape = remove_index(ind.shape(), axis); + ContiguousIterator ind_it( + shape, remove_index(ind.strides(), axis), src.ndim() - 1); + ContiguousIterator src_it( + shape, remove_index(src.strides(), axis), src.ndim() - 1); auto ind_ptr = ind.data(); auto src_ptr = src.data(); @@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { template void scatter_axis(array& out, const array idx, const array& upd, int axis) { - auto strides = idx.strides(); - strides.erase(strides.begin() + axis); - auto shape = idx.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator idx_it(shape, strides, upd.ndim() - 1); - - strides = upd.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator upd_it(shape, strides, upd.ndim() - 1); + auto shape = remove_index(idx.shape(), axis); + ContiguousIterator idx_it( + shape, remove_index(idx.strides(), axis), upd.ndim() - 1); + ContiguousIterator upd_it( + shape, remove_index(upd.strides(), axis), upd.ndim() - 1); auto idx_ptr = idx.data(); auto upd_ptr = upd.data(); diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index 2911c63f8..b242093ff 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -2,14 +2,14 @@ #pragma once -// Required for Visual Studio. -// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md -#ifdef _MSC_VER #include #define LAPACK_COMPLEX_CUSTOM #define lapack_complex_float std::complex #define lapack_complex_double std::complex -#endif +#define lapack_complex_float_real(z) ((z).real()) +#define lapack_complex_float_imag(z) ((z).imag()) +#define lapack_complex_double_real(z) ((z).real()) +#define lapack_complex_double_imag(z) ((z).imag()) #ifdef MLX_USE_ACCELERATE #include @@ -32,7 +32,7 @@ #endif -#define INSTANTIATE_LAPACK_TYPES(FUNC) \ +#define INSTANTIATE_LAPACK_REAL(FUNC) \ template \ void FUNC(Args... args) { \ if constexpr (std::is_same_v) { \ @@ -42,11 +42,24 @@ } \ } -INSTANTIATE_LAPACK_TYPES(geqrf) -INSTANTIATE_LAPACK_TYPES(orgqr) -INSTANTIATE_LAPACK_TYPES(syevd) -INSTANTIATE_LAPACK_TYPES(potrf) -INSTANTIATE_LAPACK_TYPES(gesvdx) -INSTANTIATE_LAPACK_TYPES(getrf) -INSTANTIATE_LAPACK_TYPES(getri) -INSTANTIATE_LAPACK_TYPES(trtri) +INSTANTIATE_LAPACK_REAL(geqrf) +INSTANTIATE_LAPACK_REAL(orgqr) +INSTANTIATE_LAPACK_REAL(syevd) +INSTANTIATE_LAPACK_REAL(geev) +INSTANTIATE_LAPACK_REAL(potrf) +INSTANTIATE_LAPACK_REAL(gesvdx) +INSTANTIATE_LAPACK_REAL(getrf) +INSTANTIATE_LAPACK_REAL(getri) +INSTANTIATE_LAPACK_REAL(trtri) + +#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \ + template \ + void FUNC(Args... 
args) { \ + if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_COMPLEX(heevd) diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 0be7c79ce..fbee6118f 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" @@ -52,6 +53,58 @@ inline void mask_matrix( } } +template +inline void segmented_mm( + const T* a, + const T* b, + const uint32_t* segments, + T* out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides, + size_t num_segments, + const Shape& segments_shape, + const Strides& segments_strides) { + int ndim = a_shape.size(); + Shape a_copy = a_shape; + Shape b_copy = b_shape; + int32_t M = a_copy[ndim - 2]; + int32_t N = b_copy[ndim - 1]; + for (int i = 0; i < num_segments; i++) { + uint32_t k_start = + segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; + uint32_t k_end = + segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; + if (k_end <= k_start) { + std::fill_n(out + i * M * N, M * N, T(0)); + continue; + } + a_copy[ndim - 1] = k_end - k_start; + b_copy[ndim - 2] = k_end - k_start; + matmul( + a + k_start * a_strides[ndim - 1], + b + k_start * b_strides[ndim - 2], + out + i * M * N, + a_transposed, + b_transposed, + lda, + ldb, + N, + 1.0, + 0.0, + 1, + a_copy, + a_strides, + b_copy, + b_strides); + } +} + } // namespace void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { @@ -437,4 +490,121 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { encoder.add_temporaries(std::move(temps)); } +void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& encoder = cpu::get_command_encoder(stream()); + auto check_transpose = [&s, &encoder](const array& x) { + auto stx = x.strides()[x.ndim() - 2]; + auto sty = x.strides()[x.ndim() - 1]; + if (stx == x.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, x); + } else if (stx == 1 && sty == x.shape(-2)) { + return std::make_tuple(true, sty, x); + } else { + array xc(x.shape(), x.dtype(), nullptr, {}); + copy(x, xc, CopyType::General, s); + encoder.add_temporary(xc); + int64_t stx = x.shape(-1); + return std::make_tuple(false, stx, xc); + } + }; + + auto [a_transposed, lda, a] = check_transpose(inputs[0]); + auto [b_transposed, ldb, b] = check_transpose(inputs[1]); + auto& segments = inputs[2]; + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(segments); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + segments = array::unsafe_weak_copy(segments), + out_ptr = out.data(), + a_transposed = a_transposed, + b_transposed = b_transposed, + lda = lda, + ldb = ldb]() { + switch (a.dtype()) { + case float64: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + 
segments.strides()); + break; + case float32: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case float16: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case bfloat16: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + default: + throw std::invalid_argument( + "Segmented mm supports only real float types."); + } + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index 8ae99ab2d..b944aacc0 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[AddMM::eval_cpu] Currently only supports float32."); } + if (out.size() == 0) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } // Fill output with C auto& c = inputs[2]; @@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { ? CopyType::Scalar : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); copy(c, out, ctype, stream()); - + if (inputs[0].shape(-1) == 0) { + return; + } matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); } diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index f0ac9d57f..ee8e56cc0 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -13,9 +13,18 @@ namespace mlx::core { namespace { +inline constexpr short get_pack_factor(int bits, int wsize = 8) { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) { + auto power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 
5 : 3); +} + template void extract_bits(const uint8_t* w_in, T* w_out) { - assert(bits == 3 || bits == 6); + static_assert(bits == 3 || bits == 5 || bits == 6); if (bits == 3) { w_out[0] = static_cast(w_in[0] & 0x7); w_out[1] = static_cast((w_in[0] & 0x38) >> 3); @@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) { w_out[5] = static_cast(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1)); w_out[6] = static_cast((w_in[2] & 0x1c) >> 2); w_out[7] = static_cast((w_in[2] & 0xe0) >> 5); + } else if (bits == 5) { + w_out[0] = static_cast(w_in[0] & 0x1f); + w_out[1] = static_cast(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3)); + w_out[2] = static_cast((w_in[1] & 0x7c) >> 2); + w_out[3] = static_cast(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1)); + w_out[4] = static_cast(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4)); + w_out[5] = static_cast((w_in[3] & 0x3e) >> 1); + w_out[6] = static_cast(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2)); + w_out[7] = static_cast((w_in[4] & 0xf8) >> 3); + } else if (bits == 6) { w_out[0] = static_cast(w_in[0] & 0x3f); w_out[1] = @@ -46,8 +65,8 @@ void _qmm( int N, int K) { constexpr int bitmask = (1 << bits) - 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + constexpr int pack_factor = get_pack_factor(bits, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(bits); constexpr int packs_in_group = group_size / pack_factor; for (int m = 0; m < M; m++) { @@ -65,7 +84,7 @@ void _qmm( T scale = *scales_local++; T bias = *biases_local++; for (int ng = 0; ng < packs_in_group; ng++) { - if (bits == 3 || bits == 6) { + if constexpr (bits == 3 || bits == 5 || bits == 6) { T wl[pack_factor]; extract_bits(w_local, wl); #pragma clang loop unroll(full) @@ -104,8 +123,9 @@ void _qmm_t( int N, int K) { constexpr int bitmask = (1 << bits) - 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + + constexpr int pack_factor = get_pack_factor(bits, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(bits); constexpr int packs_in_group = group_size / pack_factor; for (int m = 0; m < M; m++) { @@ -121,7 +141,7 @@ void _qmm_t( T bias = *biases_local++; for (int kw = 0; kw < packs_in_group; kw++) { - if (bits == 3 || bits == 6) { + if constexpr (bits == 3 || bits == 5 || bits == 6) { T wl[pack_factor]; extract_bits(w_local, wl); #pragma clang loop unroll(full) @@ -304,6 +324,10 @@ void _qmm_dispatch_typed( _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; + case 5: + _qmm_dispatch_group( + result, x, w, scales, biases, M, N, K, group_size, transposed_w); + break; case 6: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); @@ -613,9 +637,8 @@ void quantize( float eps = 1e-7; bool power_of_2_bits = is_power_of_2(bits); - int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32 - int bytes_per_pack = power_of_2_bits ? 
1 : 3; + int el_per_int = get_pack_factor(bits, 32); + int bytes_per_pack = get_bytes_per_pack(bits); int int_per_group = group_size * bytes_per_pack / el_per_int; size_t n_groups = w_size / group_size; @@ -640,15 +663,21 @@ void quantize( } size_t out_idx = i * int_per_group; for (int j = 0; j < int_per_group / bytes_per_pack; ++j) { - uint32_t out_el = 0; + uint64_t out_el = 0; for (int k = 0; k < el_per_int; ++k) { float w_el = w[w_idx + j * el_per_int + k]; w_el = std::rint((w_el - bias) / scale); w_el = std::min(std::max(w_el, 0.0f), n_bins); - out_el |= static_cast(w_el) << (k * bits); + out_el |= static_cast(w_el) << (k * bits); } if (power_of_2_bits) { out[out_idx + j] = out_el; + } else if (bits == 5) { + out[out_idx + bytes_per_pack * j] = out_el & 0xff; + out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8; + out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16; + out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24; + out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32; } else { out[out_idx + bytes_per_pack * j] = out_el & 0xff; out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8; diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index ce25feb11..8febbd050 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -325,7 +325,15 @@ struct MaxReduce { }; template - T operator()(simd::Simd x) { + std::enable_if_t, T> operator()(simd::Simd x) { + return simd::max(x); + }; + + template + std::enable_if_t, T> operator()(simd::Simd x) { + if (simd::any(x != x)) { + return static_cast(NAN); + } return simd::max(x); }; }; @@ -342,7 +350,15 @@ struct MinReduce { }; template - T operator()(simd::Simd x) { + std::enable_if_t, T> operator()(simd::Simd x) { + return simd::min(x); + }; + + template + std::enable_if_t, T> operator()(simd::Simd x) { + if (simd::any(x != x)) { + return static_cast(NAN); + } return simd::min(x); }; }; @@ -527,10 +543,10 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int8: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 199dbab35..33addd161 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { reduce_type_, in, out, axis_, reverse_, inclusive_); break; case complex64: - throw std::runtime_error("Scan ops do not support complex types yet"); + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); break; } }); diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index 7e82a4d56..17cd35b9a 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log10, std::log10) -DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) +template +Simd log1p(Simd in) { + if constexpr (is_complex) { + auto x = in.value.real(); + auto y = in.value.imag(); + auto zabs = 
std::abs(in.value); + auto theta = std::atan2(y, x + 1); + if (zabs < 0.5) { + auto r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return Simd{T{x, theta}}; + } + return Simd{T{((typeof(x))(0.5)) * std::log1p(r), theta}}; + } else { + auto z0 = std::hypot(x + 1, y); + return Simd{T{std::log(z0), theta}}; + } + } else { + return Simd{std::log1p(in.value)}; + } +} + template Simd log2(Simd in) { if constexpr (is_complex) { diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h index fa539541c..14c1dd479 100644 --- a/mlx/backend/cpu/unary.h +++ b/mlx/backend/cpu/unary.h @@ -2,32 +2,13 @@ #pragma once -#include "mlx/allocator.h" -#include "mlx/array.h" -#include "mlx/backend/common/utils.h" +#include "mlx/backend/common/unary.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/utils.h" namespace mlx::core { -void set_unary_output_data(const array& in, array& out) { - if (in.flags().contiguous) { - if (is_donatable(in, out)) { - out.copy_shared_buffer(in); - } else { - auto size = in.data_size(); - out.set_data( - allocator::malloc(size * out.itemsize()), - size, - in.strides(), - in.flags()); - } - } else { - out.set_data(allocator::malloc(out.nbytes())); - } -} - template void unary_op(const T* a, U* out, size_t shape, size_t stride) { for (size_t i = 0; i < shape; i += 1) { diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt new file mode 100644 index 000000000..9f236b4ea --- /dev/null +++ b/mlx/backend/cuda/CMakeLists.txt @@ -0,0 +1,132 @@ +# Filename rules in cuda backend: +# +# * Use .cu/.cuh if code contains device code, and .cpp/.h if not. +# * Device-only code should be put in device/ subdir. +# * Files in device/ subdir should not include files outside. 
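+#
+# Only the .cu files below are compiled by nvcc; the .cpp files stay host-only.
+# Headers under device/ are additionally embedded into the binary (see the
+# bin2h rule further down) so that NVRTC can compile them at runtime for JIT
+# kernels.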
+target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu + ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu + ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cu + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/random.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu + ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu + ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu + ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu + ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu + ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) + +target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) + +# Embed kernel sources in binary for JIT compilation. +file( + GLOB MLX_JIT_SOURCES + RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh") +string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) +add_custom_command( + OUTPUT gen/cuda_jit_sources.h + COMMAND + ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR} + -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P + "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake" + DEPENDS bin2h.cmake ${MLX_JIT_SOURCES}) +add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h) +add_dependencies(mlx cuda_jit_sources) +target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen") + +# Enable defining device lambda functions. +target_compile_options(mlx + PRIVATE "$<$:--extended-lambda>") + +# Enable calling host constexpr functions from device. This is needed because +# the constexpr version of isnan is host only. +target_compile_options( + mlx PRIVATE "$<$:--expt-relaxed-constexpr>") + +# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive. +# Explicitly pass this flag to suppress the warning, it is safe to set it to +# true but the warning wouldn't be suppressed. +if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0) + target_compile_options( + mlx + PRIVATE "$<$:--static-global-template-stub=false>") +endif() + +# Suppress warning when building for compute capability 7 used by V100. +target_compile_options( + mlx PRIVATE "$<$:--Wno-deprecated-gpu-targets>") + +# Compute capability 7 is required for synchronization between CPU/GPU with +# managed memory. TODO: Add more architectures for potential performance gain. 
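+# The default below targets compute capability 7.0 (V100) and 8.0 (A100); as a
+# CACHE variable it can be overridden at configure time, for example with
+# -DMLX_CUDA_ARCHITECTURES="80;90".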
+set(MLX_CUDA_ARCHITECTURES + "70;80" + CACHE STRING "CUDA architectures") +message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") +set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES + "${MLX_CUDA_ARCHITECTURES}") + +# Use fixed version of CCCL. +FetchContent_Declare( + cccl + URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") +FetchContent_MakeAvailable(cccl) +target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include") + +# Use fixed version of NVTX. +FetchContent_Declare( + nvtx3 + GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git + GIT_TAG v3.1.1 + GIT_SHALLOW TRUE + SOURCE_SUBDIR c EXCLUDE_FROM_ALL) +FetchContent_MakeAvailable(nvtx3) +target_link_libraries(mlx PUBLIC $) + +# Make cuda runtime APIs available in non-cuda files. +find_package(CUDAToolkit REQUIRED) +target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + +# Use cublasLt. +target_link_libraries(mlx PRIVATE CUDA::cublasLt) + +# Use NVRTC and driver APIs. +target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver) + +# Suppress nvcc warnings on MLX headers. +target_compile_options(mlx PRIVATE $<$:-Xcudafe + --diag_suppress=997>) + +# Install CCCL headers for JIT. +install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp new file mode 100644 index 000000000..6cc7145b5 --- /dev/null +++ b/mlx/backend/cuda/allocator.cpp @@ -0,0 +1,215 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/utils.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +constexpr int page_size = 16384; + +CudaAllocator::CudaAllocator() + : buffer_cache_( + page_size, + [](CudaBuffer* buf) { return buf->size; }, + [this](CudaBuffer* buf) { + cuda_free(buf->data); + delete buf; + }) { + // TODO: Set memory limit for multi-device. + size_t free, total; + CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); + memory_limit_ = total * 0.8; + max_pool_size_ = memory_limit_; +} + +Buffer CudaAllocator::malloc(size_t size) { + // Find available buffer from cache. + auto orig_size = size; + std::unique_lock lock(mutex_); + if (size < page_size) { + size = next_power_of_2(size); + } else { + size = page_size * ((size + page_size - 1) / page_size); + } + + CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); + if (!buf) { + // If we have a lot of memory pressure or are over the maximum cache size, + // try to reclaim memory from the cache. + size_t mem_required = get_active_memory() + get_cache_memory() + size; + if (mem_required >= memory_limit_) { + buffer_cache_.release_cached_buffers(mem_required - memory_limit_); + } + + lock.unlock(); + buf = new CudaBuffer{nullptr, size}; + cudaError_t err = cudaMallocManaged(&buf->data, size); + if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { + throw std::runtime_error(fmt::format( + "cudaMallocManaged failed: {}.", cudaGetErrorString(err))); + } + lock.lock(); + } + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + + // Maintain the cache below the requested limit. 
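+ // The cache can exceed max_pool_size_, e.g. after set_cache_limit lowered
+ // the limit, so trim any excess before handing the buffer out.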
+ if (get_cache_memory() > max_pool_size_) { + buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); + } + + return Buffer{buf}; +} + +void CudaAllocator::free(Buffer buffer) { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return; + } + + std::unique_lock lock(mutex_); + active_memory_ -= buf->size; + if (get_cache_memory() < max_pool_size_) { + buffer_cache_.recycle_to_cache(buf); + } else { + lock.unlock(); + cuda_free(buf->data); + delete buf; + } +} + +size_t CudaAllocator::size(Buffer buffer) const { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return 0; + } + return buf->size; +} + +void CudaAllocator::register_this_thread() { + std::lock_guard lock(worker_mutex_); + allowed_threads_.insert(std::this_thread::get_id()); +} + +void CudaAllocator::cuda_free(void* buf) { + // If cuda_free() is called from a unregistered thread, reschedule the call to + // worker. + { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([this, buf]() { this->cuda_free(buf); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + cudaFree(buf); +} + +size_t CudaAllocator::get_active_memory() const { + return active_memory_; +} + +size_t CudaAllocator::get_peak_memory() const { + return peak_memory_; +} + +void CudaAllocator::reset_peak_memory() { + std::lock_guard lock(mutex_); + peak_memory_ = 0; +} + +size_t CudaAllocator::get_memory_limit() { + return memory_limit_; +} + +size_t CudaAllocator::set_memory_limit(size_t limit) { + std::lock_guard lock(mutex_); + std::swap(limit, memory_limit_); + return limit; +} + +size_t CudaAllocator::get_cache_memory() const { + return buffer_cache_.cache_size(); +} + +size_t CudaAllocator::set_cache_limit(size_t limit) { + std::lock_guard lk(mutex_); + std::swap(limit, max_pool_size_); + return limit; +} + +void CudaAllocator::clear_cache() { + std::lock_guard lk(mutex_); + buffer_cache_.clear(); +} + +CudaAllocator& allocator() { + // By creating the |allocator_| on heap, the destructor of CudaAllocator + // will not be called on exit and buffers in the cache will be leaked. This + // can save some time at program exit. + static CudaAllocator* allocator_ = new CudaAllocator; + return *allocator_; +} + +} // namespace cu + +namespace allocator { + +Allocator& allocator() { + return cu::allocator(); +} + +void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } + return static_cast(ptr_)->data; +} + +} // namespace allocator + +size_t get_active_memory() { + return cu::allocator().get_active_memory(); +} +size_t get_peak_memory() { + return cu::allocator().get_peak_memory(); +} +void reset_peak_memory() { + return cu::allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return cu::allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return cu::allocator().get_memory_limit(); +} +size_t get_cache_memory() { + return cu::allocator().get_cache_memory(); +} +size_t set_cache_limit(size_t limit) { + return cu::allocator().set_cache_limit(limit); +} +void clear_cache() { + cu::allocator().clear_cache(); +} + +// Not supported in CUDA. +size_t set_wired_limit(size_t) { + return 0; +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h new file mode 100644 index 000000000..e268c6334 --- /dev/null +++ b/mlx/backend/cuda/allocator.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. 
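+//
+// CUDA allocator backed by cudaMallocManaged (unified memory). Requests are
+// rounded up (powers of two below the 16 KB page size, page multiples above)
+// and recycled through a size-keyed BufferCache. Because freeing implicitly
+// synchronizes the stream, frees issued from threads that were not registered
+// via register_this_thread() are forwarded to a background Worker.
+//
+// Illustrative use of the public memory API wired up in allocator.cpp (the
+// numbers are placeholders):
+//
+//   size_t previous = mlx::core::set_cache_limit(512 << 20); // cap the pool
+//   size_t active = mlx::core::get_active_memory();
+//   size_t cached = mlx::core::get_cache_memory();
+//   mlx::core::clear_cache(); // drop cached buffers immediately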
+ +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +class Worker; + +using allocator::Buffer; + +// Stores cuda-managed unified memory. +struct CudaBuffer { + void* data; + size_t size; +}; + +class CudaAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // Register current thread as safe to free buffers. + // In cuda freeing a buffer implicitly synchronizes stream, and for threads + // that may be waited by gpu stream (for example cpu stream threads), freeing + // buffers there would result in dead lock. + void register_this_thread(); + + // Call cudaFree in the safe thread. + void cuda_free(void* buf); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + size_t get_cache_memory() const; + size_t set_cache_limit(size_t limit); + void clear_cache(); + + private: + CudaAllocator(); + friend CudaAllocator& allocator(); + + std::mutex worker_mutex_; + std::unique_ptr worker_; + std::set allowed_threads_; + + std::mutex mutex_; + size_t memory_limit_; + size_t max_pool_size_; + BufferCache buffer_cache_; + size_t active_memory_{0}; + size_t peak_memory_{0}; +}; + +CudaAllocator& allocator(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu new file mode 100644 index 000000000..67ef5d968 --- /dev/null +++ b/mlx/backend/cuda/arg_reduce.cu @@ -0,0 +1,182 @@ +// Copyright © 2025 Apple Inc. +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/iterators/strided_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +struct IndexValPair { + uint32_t index; + T val; +}; + +template +struct ArgMin { + constexpr __device__ T init() { + return Limits::max(); + } + + __device__ IndexValPair operator()( + const IndexValPair& best, + const IndexValPair& current) { + if (best.val > current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + __device__ IndexValPair + reduce_many(IndexValPair best, T (&vals)[N], uint32_t offset) { + for (int i = 0; i < N; i++) { + if (vals[i] < best.val) { + best.val = vals[i]; + best.index = offset + i; + } + } + return best; + } +}; + +template +struct ArgMax { + constexpr __device__ T init() { + return Limits::min(); + } + + __device__ IndexValPair operator()( + const IndexValPair& best, + const IndexValPair& current) { + if (best.val < current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + __device__ IndexValPair + reduce_many(IndexValPair best, T (&vals)[N], uint32_t offset) { + for (int i = 0; i < N; i++) { + if (vals[i] > best.val) { + best.val = vals[i]; + best.index = offset + i; + } + } + return best; + } +}; + +template +__global__ void arg_reduce_general( + const T* in, + uint32_t* out, + size_t size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides 
in_strides, + const __grid_constant__ Strides out_strides, + int32_t ndim, + int64_t axis_stride, + int32_t axis_size) { + auto block = cg::this_thread_block(); + + int64_t index = cg::this_grid().block_rank(); + if (index >= size) { + return; + } + + int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim); + int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim); + + Op op; + T init = op.init(); + IndexValPair best{0, init}; + + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T vals[N_READS]; + auto tid = r * BLOCK_DIM + block.thread_index().x; + cub::LoadDirectBlocked( + tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init); + best = op.reduce_many(best, vals, tid * N_READS); + } + + typedef cub::BlockReduce, BLOCK_DIM> BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp; + + best = BlockReduceT(temp).Reduce(best, op); + + if (block.thread_rank() == 0) { + out[out_idx] = best.index; + } +} + +} // namespace cu + +void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ArgReduce::eval_gpu"); + assert(inputs.size() == 1); + auto& in = inputs[0]; + out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + + // Prepare the shapes, strides and axis arguments. + Shape shape = remove_index(in.shape(), axis_); + Strides in_strides = remove_index(in.strides(), axis_); + Strides out_strides = out.ndim() == in.ndim() + ? remove_index(out.strides(), axis_) + : out.strides(); + int64_t axis_stride = in.strides()[axis_]; + int32_t axis_size = in.shape()[axis_]; + int32_t ndim = shape.size(); + + // ArgReduce. + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { + using T = cuda_type_t; + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + auto kernel = + cu::arg_reduce_general, block_dim(), N_READS>; + if (reduce_type_ == ArgReduce::ArgMin) { + kernel = cu::arg_reduce_general, block_dim(), N_READS>; + } + encoder.add_kernel_node( + kernel, + num_blocks, + block_dim(), + in.data(), + out.data(), + out.size(), + const_param(shape), + const_param(in_strides), + const_param(out_strides), + ndim, + axis_stride, + axis_size); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/bin2h.cmake b/mlx/backend/cuda/bin2h.cmake new file mode 100644 index 000000000..b791d3d1a --- /dev/null +++ b/mlx/backend/cuda/bin2h.cmake @@ -0,0 +1,150 @@ +# Based on: https://github.com/sivachandran/cmake-bin2h +# +# Copyright 2020 Sivachandran Paramasivam +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +include(CMakeParseArguments) + +# Function to wrap a given string into multiple lines at the given column +# position. +# +# Parameters: +# +# * VARIABLE - The name of the CMake variable holding the string. +# * AT_COLUMN - The column position at which string will be wrapped. +function(WRAP_STRING) + set(oneValueArgs VARIABLE AT_COLUMN) + cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN}) + + string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength) + math(EXPR offset "0") + + while(stringLength GREATER 0) + if(stringLength GREATER ${WRAP_STRING_AT_COLUMN}) + math(EXPR length "${WRAP_STRING_AT_COLUMN}") + else() + math(EXPR length "${stringLength}") + endif() + + string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line) + set(lines "${lines}\n ${line}") + + math(EXPR stringLength "${stringLength} - ${length}") + math(EXPR offset "${offset} + ${length}") + endwhile() + + set(${WRAP_STRING_VARIABLE} + "${lines}" + PARENT_SCOPE) +endfunction() + +# Function to embed contents of a file as byte array in C/C++ header file(.h). +# The header file will contain a byte array and integer variable holding the +# size of the array. +# +# Parameters: +# +# * SOURCE_FILES - The paths of source files whose contents will be embedded in +# the header file. +# * VARIABLE_NAME - The name of the variable for the byte array. The string +# "_SIZE" will be append to this name and will be used a variable name for +# size variable. +# * HEADER_FILE - The path of header file. +# * APPEND - If specified appends to the header file instead of overwriting it +# * HEADER_NAMESPACE - The namespace, where the array should be located in. +# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte +# array. 
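+#
+# Each source file is emitted as a separate array named
+# "<VARIABLE_NAME>_<filename as C identifier>"; e.g. VARIABLE_NAME "jit_source"
+# and source file "utils.cuh" produce "constexpr char jit_source_utils[]".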
+# +# Usage: +# +# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG") +function(BIN2H) + set(options APPEND NULL_TERMINATE) + set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE) + set(multiValueArgs SOURCE_FILES) + cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + set(arrayDefinition "") + foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES) + # get filename without extension + get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE) + # convert the filename to a valid C identifier + string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME) + + # reads source file contents as hex string + file(READ ${SOURCE_FILE} hexString HEX) + + # append null + if(BIN2H_NULL_TERMINATE) + string(APPEND hexString "00") + endif() + + # wraps the hex string into multiple lines + wrap_string(VARIABLE hexString AT_COLUMN 24) + + # strip the © in source code + string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString}) + + string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues + ${arrayValues}) + + # make a full variable name for the array + set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}") + + # declares byte array and the length variables + string(APPEND arrayDefinition + "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n") + endforeach() + + # add namespace wrapper if defined + if(DEFINED BIN2H_HEADER_NAMESPACE) + set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {") + set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}") + set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n") + endif() + + set(arrayIncludes "#pragma once") + string(PREPEND declarations "${arrayIncludes}\n\n") + + if(BIN2H_APPEND) + file(APPEND ${BIN2H_HEADER_FILE} "${declarations}") + else() + file(WRITE ${BIN2H_HEADER_FILE} "${declarations}") + endif() +endfunction() + +# ----------------------------- CLI args ----------------------------- + +string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES}) +foreach(source ${MLX_JIT_SOURCES_LIST}) + list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}") +endforeach() + +bin2h( + SOURCE_FILES + ${MLX_JIT_SOURCES_ABS} + NULL_TERMINATE + VARIABLE_NAME + "jit_source" + HEADER_NAMESPACE + "mlx::core" + HEADER_FILE + "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h") diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu new file mode 100644 index 000000000..3eade024d --- /dev/null +++ b/mlx/backend/cuda/binary.cu @@ -0,0 +1,359 @@ +// Copyright © 2025 Apple Inc. 
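+//
+// Elementwise binary kernels. The ss/sv/vs/vv variants cover the
+// scalar/vector contiguity combinations of the two inputs; they read and
+// write N_READS elements per thread through AlignedVector, falling back to a
+// scalar loop for the final partial chunk. binary_g_nd (ndim fixed at compile
+// time, up to 3) and binary_g (arbitrary ndim) handle general strided inputs.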
+ +#include "mlx/backend/common/binary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (int i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[0], b[0]); + } + } else { + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a[0], b[0]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[0], b[i]); + } + } else { + auto b_vec = load_vector(b, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a[0], b_vec.val[i]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[0]); + } + } else { + auto a_vec = load_vector(a, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a_vec.val[i], b[0]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[i]); + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void binary_g_nd( + const In* a, + const In* b, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_nd( + index, shape.data(), a_strides.data(), b_strides.data()); + out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +template +__global__ void binary_g( + const In* a, + const In* b, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_4d( + index, shape.data(), a_strides.data(), b_strides.data(), ndim); + out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +template +constexpr bool supports_binary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + 
return std::is_same_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && std::is_integral_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + return false; +} + +} // namespace cu + +template +void binary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + if (out.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + binary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); + } else { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + // TODO: Choose optimized value based on type size. 
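+ // Each thread handles N_READS contiguous elements; get_launch_args sizes
+ // the grid accordingly and the kernels above fall back to a scalar loop
+ // for the final partial chunk.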
+ constexpr int N_READS = 4; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.data_size()); + }); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); +} + +template +void binary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + binary_op_gpu_inplace(inputs, out, op, s); +} + +#define BINARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + binary_op_gpu(inputs, out, name(), s); \ + } + +BINARY_GPU(Add) +BINARY_GPU(ArcTan2) +BINARY_GPU(Divide) +BINARY_GPU(Remainder) +BINARY_GPU(Greater) +BINARY_GPU(GreaterEqual) +BINARY_GPU(Less) +BINARY_GPU(LessEqual) +BINARY_GPU(LogicalAnd) +BINARY_GPU(LogicalOr) +BINARY_GPU(LogAddExp) +BINARY_GPU(Maximum) +BINARY_GPU(Minimum) +BINARY_GPU(Multiply) +BINARY_GPU(NotEqual) +BINARY_GPU(Power) +BINARY_GPU(Subtract) + +void Equal::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Equal::eval_gpu"); + auto& s = out.primitive().stream(); + if (equal_nan_) { + binary_op_gpu(inputs, out, name(), s); + } else { + binary_op_gpu(inputs, out, name(), s); + } +} + +void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); + auto& s = out.primitive().stream(); + switch (op_) { + case BitwiseBinary::And: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Or: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Xor: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::LeftShift: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::RightShift: + binary_op_gpu(inputs, out, name(), s); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu new file mode 100644 index 000000000..3ac8a9516 --- /dev/null +++ b/mlx/backend/cuda/binary_two.cu @@ -0,0 +1,334 @@ +// Copyright © 2025 Apple Inc. 
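+//
+// Same structure as binary.cu, but for primitives that produce two outputs
+// per element pair: the op returns an indexable pair whose components are
+// written to out_a and out_b. DivMod (quotient and remainder) is the user of
+// these kernels in this file.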
+ +#include "mlx/backend/common/binary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void +binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[0], b[0]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a[0], b[0]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); + } +} + +template +__global__ void +binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[0], b[i]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto b_vec = load_vector(b, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a[0], b_vec.val[i]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); + } +} + +template +__global__ void +binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[i], b[0]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto a_vec = load_vector(a, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec.val[i], b[0]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); + } +} + +template +__global__ void +binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[i], b[i]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec.val[i], b_vec.val[i]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); + } +} + +template +__global__ void binary_two_g_nd( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_nd( + index, shape.data(), a_strides.data(), b_strides.data()); + auto out = Op{}(a[a_idx], b[b_idx]); + out_a[index] = out[0]; + out_b[index] = out[1]; 
+ } +} + +template +__global__ void binary_two_g( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_4d( + index, shape.data(), a_strides.data(), b_strides.data(), ndim); + auto out = Op{}(a[a_idx], b[b_idx]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +constexpr bool supports_binary_two_op() { + if (std::is_same_v) { + return std::is_same_v && + (std::is_integral_v || is_floating_v); + } + return false; +} + +} // namespace cu + +template +void binary_two_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + const char* op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + auto& out_a = outputs[0]; + auto& out_b = outputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out_a, bopt); + set_binary_op_output_data(a, b, out_b, bopt); + + if (out_a.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out_a); + encoder.set_output_array(out_b); + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_two_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out_a.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::binary_two_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_two_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); + } else { + dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + // TODO: Choose optimized value based on type size. 
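+ // data_size() > UINT32_MAX selects a 64-bit IdxT so the flat thread index
+ // cannot overflow on very large contiguous arrays.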
+ constexpr int N_READS = 4; + auto kernel = cu::binary_two_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_two_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_two_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_two_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out_a.data_size(), + out_a.shape(), + out_a.strides(), + large(), + N_READS); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.data_size()); + }); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out_a.dtype()))); + } + }); + }); +} + +template +void binary_two_op_gpu( + const std::vector& inputs, + std::vector& outputs, + const char* op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); + binary_two_op_gpu_inplace(inputs, outputs, op, s); +} + +void DivMod::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("DivMod::eval_gpu"); + auto& s = outputs[0].primitive().stream(); + binary_two_op_gpu(inputs, outputs, name(), s); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp new file mode 100644 index 000000000..2f3990b90 --- /dev/null +++ b/mlx/backend/cuda/compiled.cpp @@ -0,0 +1,231 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/graph_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +struct FusedKernelBuilder { + std::string os; + const std::string& kernel_name; + const std::vector& inputs; + const std::vector& outputs; + const std::vector& tape; + const std::function& is_constant; + + void build(const char* name, bool contiguous) { + NodeNamer namer; + + // Function parameters. + std::vector params; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant(i)) { + continue; + } + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + params.push_back( + fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname)); + if (!is_scalar(x) && !contiguous) { + params.push_back(fmt::format( + "const __grid_constant__ cuda::std::array {}_strides", + xname)); + } + } + for (const auto& x : outputs) { + params.push_back(fmt::format( + "{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x))); + } + if (!contiguous) { + params.push_back( + "const __grid_constant__ cuda::std::array shape"); + } + params.push_back("IdxT size"); + + // Build function signature. + if (contiguous) { + os += "template \n"; + } else { + os += "template \n"; + } + os += fmt::format("__global__ void {}(\n", kernel_name + name); + for (size_t i = 0; i < params.size(); ++i) { + os += " "; + os += params[i]; + if (i != params.size() - 1) { + os += ",\n"; + } + } + os += ") {\n"; + + // Index. + os += + " IdxT index = cg::this_grid().thread_rank();\n" + " if (index >= size) {\n" + " return;\n" + " }\n"; + + // Read inputs. 
+ for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + std::string value; + if (is_constant(i)) { + std::ostringstream ss; + print_constant(ss, x); + value = fmt::format("static_cast<{}>({})", type, ss.str()); + } else if (is_scalar(x)) { + value = fmt::format("{}[0]", xname); + } else if (contiguous) { + value = fmt::format("{}[index]", xname); + } else { + std::string index = fmt::format( + "elem_to_loc_nd(index, shape.data(), {}_strides.data())", + xname); + value = fmt::format("{}[{}]", xname, index); + } + os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); + } + + // Write tape. + for (const auto& x : tape) { + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + std::string value; + if (is_static_cast(x.primitive())) { + value = fmt::format( + "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0])); + } else { + value = x.primitive().name(); + value += "{}("; + for (size_t i = 0; i < x.inputs().size() - 1; ++i) { + value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i])); + } + value += fmt::format("tmp_{})", namer.get_name(x.inputs().back())); + } + os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); + } + + // Write output. + for (const auto& x : outputs) { + os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); + } + + os += "}\n"; + } +}; + +} // namespace cu + +constexpr const char* g_jit_includes = R"( +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/ternary_ops.cuh" +#include "mlx/backend/cuda/device/unary_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +#define inf cuda::std::numeric_limits::infinity() +)"; + +void Compiled::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("Compiled::eval_gpu"); + auto& s = stream(); + + cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() { + // Build source code. + cu::FusedKernelBuilder builder{ + g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_}; + builder.os += + "namespace mlx::core::cu {\n\n" + "namespace cg = cooperative_groups;\n\n"; + builder.build("_contiguous", true); + builder.os += "\n"; + builder.build("_strided", false); + builder.os += "\n} // namespace mlx::core::cu\n"; + // Build kernel names. + std::vector kernel_names = { + fmt::format("mlx::core::cu::{}_contiguous", lib_name()), + fmt::format("mlx::core::cu::{}_contiguous", lib_name()), + }; + for (int i = 1; i <= MAX_NDIM; ++i) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i)); + kernel_names.push_back( + fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i)); + } + return std::make_pair(std::move(builder.os), std::move(kernel_names)); + }); + + // Collapse contiguous dims to route to a faster kernel if possible. Also + // handle all broadcasting. + auto [contiguous, shape, strides_vec] = + compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); + + // Whether to use large index. + bool large = compiled_use_large_index(inputs, outputs, contiguous); + + cu::KernelArgs args; + // Put inputs. 
+ int strides_index = 1; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant_(i)) { + continue; + } + const auto& x = inputs[i]; + args.append(x); + if (!contiguous && !is_scalar(x)) { + args.append_ptr(strides_vec[strides_index++].data()); + } + } + + // Put outputs. + compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); + for (auto& x : outputs) { + args.append(x); + } + + // Put shape and size. + if (!contiguous) { + args.append_ptr(shape.data()); + } + if (large) { + args.append(outputs[0].data_size()); + } else { + args.append(outputs[0].data_size()); + } + + // Launch kernel. + const char* index_type = large ? "int64_t" : "uint32_t"; + std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name()); + if (contiguous) { + kernel_name += fmt::format("_contiguous<{}>", index_type); + } else { + kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type); + } + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + for (const auto& out : outputs) { + encoder.set_output_array(out); + } + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu new file mode 100644 index 000000000..321806720 --- /dev/null +++ b/mlx/backend/cuda/copy.cu @@ -0,0 +1,87 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/copy/copy.cuh" + +namespace mlx::core { + +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + int64_t offset_in, + int64_t offset_out, + CopyType ctype, + const Stream& s, + const std::optional& dynamic_offset_in, + const std::optional& dynamic_offset_out) { + if (out.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + if (ctype == CopyType::Scalar || ctype == CopyType::Vector) { + copy_contiguous(encoder, ctype, in, out, offset_in, offset_out); + return; + } + + if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { + auto [shape_collapsed, strides_vec] = collapse_contiguous_dims( + shape, std::vector{strides_in, strides_out}, INT32_MAX); + if (ctype == CopyType::General) { + copy_general_input( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0]); + } else { + if (dynamic_offset_in || dynamic_offset_out) { + copy_general_dynamic( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0], + strides_vec[1], + dynamic_offset_in ? *dynamic_offset_in : array(0, int64), + dynamic_offset_out ? 
*dynamic_offset_out : array(0, int64)); + } else { + copy_general( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0], + strides_vec[1]); + } + } + return; + } +} + +void fill_gpu(const array& in, array& out, const Stream& s) { + if (out.size() == 0) { + return; + } + out.set_data(allocator::malloc(out.nbytes())); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh new file mode 100644 index 000000000..e80fdec8c --- /dev/null +++ b/mlx/backend/cuda/copy/copy.cuh @@ -0,0 +1,55 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +void copy_contiguous( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out); + +void copy_general( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out); + +void copy_general_dynamic( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + const array& dynamic_offset_in, + const array& dynamic_offset_out); + +void copy_general_input( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu new file mode 100644 index 000000000..4e9eaccb7 --- /dev/null +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -0,0 +1,93 @@ +// Copyright © 2025 Apple Inc. 
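
// Editor's sketch (illustrative, not from this patch): the contiguous copy
// kernels in the file below — like the binary_two_* kernels earlier — give
// each thread a window of N consecutive elements and let the thread whose
// window crosses the end fall back to a scalar loop. The real kernels also
// use aligned vector loads/stores (load_vector/store_vector), which this
// minimal standalone version omits; all names here are hypothetical.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, int N>
__global__ void copy_vectorized(const T* in, T* out, size_t size) {
  size_t index = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  if ((index + 1) * N > size) {
    // Tail: this thread's window would run past the end, copy one by one.
    for (size_t i = index * N; i < size; ++i) {
      out[i] = in[i];
    }
  } else {
    // Fast path: a full window of N elements.
#pragma unroll
    for (int i = 0; i < N; ++i) {
      out[index * N + i] = in[index * N + i];
    }
  }
}

int main() {
  constexpr int N = 4;
  const size_t size = 1000;  // deliberately not a multiple of N
  float *in, *out;
  cudaMallocManaged(&in, size * sizeof(float));
  cudaMallocManaged(&out, size * sizeof(float));
  for (size_t i = 0; i < size; ++i) in[i] = float(i);
  size_t threads = (size + N - 1) / N;  // one thread per N elements
  int block = 256;
  int grid = int((threads + block - 1) / block);
  copy_vectorized<float, N><<<grid, block>>>(in, out, size);
  cudaDeviceSynchronize();
  printf("out[999] = %g\n", out[999]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
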
+ +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_s(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = cast_to(in[0]); + } + } else { + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = cast_to(in[0]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void copy_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = cast_to(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = cast_to(in_vec.val[i]); + } + + store_vector(out, index, out_vec); + } +} + +} // namespace cu + +void copy_contiguous( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t in_offset, + int64_t out_offset) { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + // TODO: Choose optimized value based on type size. + constexpr int N_READS = 4; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu new file mode 100644 index 000000000..5c7f9f954 --- /dev/null +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -0,0 +1,110 @@ +// Copyright © 2025 Apple Inc. 
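
// Editor's sketch (illustrative, not from this patch): the host dispatch in
// copy_contiguous above selects a 32- or 64-bit index type at compile time
// and branches on the array size at run time, so small arrays only pay for
// cheap uint32_t arithmetic. dispatch_large and launch_copy are hypothetical
// stand-ins for the dispatch_bool helper and the real kernel launch.
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <utility>

template <typename IdxT>
void launch_copy(size_t size) {
  std::printf("launching with %zu-bit indices for %zu elements\n",
              sizeof(IdxT) * 8, size);
}

template <typename F>
void dispatch_large(bool large, F&& f) {
  if (large) {
    std::forward<F>(f)(std::integral_constant<bool, true>{});
  } else {
    std::forward<F>(f)(std::integral_constant<bool, false>{});
  }
}

int main() {
  for (size_t size : {size_t(1) << 20, size_t(1) << 33}) {
    dispatch_large(size > UINT32_MAX, [&](auto large) {
      // `large` is a compile-time constant inside the lambda, so IdxT is
      // chosen per instantiation: exactly one kernel variant per branch.
      using IdxT =
          std::conditional_t<decltype(large)::value, int64_t, uint32_t>;
      launch_copy<IdxT>(size);
    });
  }
  return 0;
}
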
+ +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_gg_nd( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array strides_in, + const __grid_constant__ cuda::std::array strides_out) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_nd( + index, shape.data(), strides_in.data(), strides_out.data()); + out[idx_out] = CastOp{}(in[idx_in]); + } +} + +template +__global__ void copy_gg( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides_in, + const __grid_constant__ Strides strides_out, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_4d( + index, shape.data(), strides_in.data(), strides_out.data(), ndim); + out[idx_out] = CastOp{}(in[idx_in]); + } +} + +} // namespace cu + +void copy_general( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out) { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + auto kernel = + cu::copy_gg_nd; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu new file mode 100644 index 000000000..1b643111f --- /dev/null +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -0,0 +1,117 @@ +// Copyright © 2025 Apple Inc. 
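
// Editor's sketch (illustrative, not from this patch): a host-side rendering
// of the stride arithmetic the strided kernels above perform per thread via
// elem_to_loc_nd / elem_to_loc_4d — walk the shape from the innermost
// dimension outward, turning a flat row-major index into a memory offset.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t elem_to_loc(
    int64_t index,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = int(shape.size()) - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 row-major buffer viewed as its 3x2 transpose has strides {1, 3};
  // traversing the view in row-major order visits offsets 0 3 1 4 2 5.
  std::vector<int32_t> shape = {3, 2};
  std::vector<int64_t> strides = {1, 3};
  for (int64_t i = 0; i < 6; ++i) {
    std::printf("flat %lld -> offset %lld\n",
                (long long)i, (long long)elem_to_loc(i, shape, strides));
  }
  return 0;
}
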
+ +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_gg_dynamic_nd( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array strides_in, + const __grid_constant__ cuda::std::array strides_out, + const int64_t* offset_in, + const int64_t* offset_out) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_nd( + index, shape.data(), strides_in.data(), strides_out.data()); + out[idx_out + *offset_out] = CastOp{}(in[idx_in + *offset_in]); + } +} + +template +__global__ void copy_gg_dynamic( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides_in, + const __grid_constant__ Strides strides_out, + int ndim, + const int64_t* offset_in, + const int64_t* offset_out) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_4d( + index, shape.data(), strides_in.data(), strides_out.data(), ndim); + out[idx_out + *offset_out] = CastOp{}(in[idx_in + *offset_in]); + } +} + +} // namespace cu + +void copy_general_dynamic( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + const array& dynamic_offset_in, + const array& dynamic_offset_out) { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + copy_gg_dynamic_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + dynamic_offset_in.data(), + dynamic_offset_out.data()); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu new file mode 100644 index 000000000..1ac7712e6 --- /dev/null +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -0,0 +1,100 @@ +// Copyright © 2025 Apple Inc. 
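
// Editor's sketch (illustrative, not from this patch): the dynamic-offset
// copy above dereferences its offsets *inside* the kernel, so they can come
// from earlier GPU work (or change between graph replays) without new launch
// parameters. Reduced here to a 1-D copy with hypothetical names.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void copy_with_dynamic_offset(
    const float* in,
    float* out,
    size_t size,
    const int64_t* offset_in,    // read on the device at run time
    const int64_t* offset_out) {
  size_t i = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < size) {
    out[i + *offset_out] = in[i + *offset_in];
  }
}

int main() {
  float *in, *out;
  int64_t* offsets;
  cudaMallocManaged(&in, 16 * sizeof(float));
  cudaMallocManaged(&out, 16 * sizeof(float));
  cudaMallocManaged(&offsets, 2 * sizeof(int64_t));
  for (int i = 0; i < 16; ++i) { in[i] = float(i); out[i] = 0.f; }
  offsets[0] = 4;  // input offset, could have been written by a prior kernel
  offsets[1] = 8;  // output offset
  copy_with_dynamic_offset<<<1, 8>>>(in, out, 8, &offsets[0], &offsets[1]);
  cudaDeviceSynchronize();
  printf("out[8] = %g (expected 4)\n", out[8]);
  cudaFree(in); cudaFree(out); cudaFree(offsets);
  return 0;
}
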
+ +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_g_nd( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array strides_in) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + IdxT idx_in = elem_to_loc_nd(index, shape.data(), strides_in.data()); + out[index] = CastOp{}(in[idx_in]); + } +} + +template +__global__ void copy_g( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides_in, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim); + out[index] = CastOp{}(in[idx_in]); + } +} + +} // namespace cu + +void copy_general_input( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in) { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::copy_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + ndim); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/cuda.cpp b/mlx/backend/cuda/cuda.cpp new file mode 100644 index 000000000..ceb4d7dfe --- /dev/null +++ b/mlx/backend/cuda/cuda.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cuda.h" + +namespace mlx::core::cu { + +bool is_available() { + return true; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/cuda.h b/mlx/backend/cuda/cuda.h new file mode 100644 index 000000000..2c6a5c724 --- /dev/null +++ b/mlx/backend/cuda/cuda.h @@ -0,0 +1,10 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cu { + +/* Check if the CUDA backend is available. */ +bool is_available(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp new file mode 100644 index 000000000..f7c8ecdc0 --- /dev/null +++ b/mlx/backend/cuda/device.cpp @@ -0,0 +1,339 @@ +// Copyright © 2025 Apple Inc. 
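
// Editor's sketch (illustrative, not from this patch): the Device constructor
// in the file below queries compute capability and requires concurrent
// managed-memory access. The same attribute queries in a standalone program:
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int count = 0;
  if (cudaGetDeviceCount(&count) != cudaSuccess || count == 0) {
    std::printf("no CUDA device visible\n");
    return 1;
  }
  for (int d = 0; d < count; ++d) {
    int concurrent = 0, major = 0, minor = 0;
    cudaDeviceGetAttribute(&concurrent, cudaDevAttrConcurrentManagedAccess, d);
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, d);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, d);
    std::printf("device %d: sm_%d%d, concurrent managed access: %s\n",
                d, major, minor, concurrent ? "yes" : "no");
  }
  return 0;
}
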
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/utils.h" + +#include +#include +#include +#include + +namespace mlx::core { + +// Can be tuned with MLX_MAX_OPS_PER_BUFFER +// This should be less than 255 +constexpr int default_max_nodes_per_graph = 20; + +int cuda_graph_cache_size() { + static int cache_size = []() { + return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100); + }(); + return cache_size; +} + +namespace cu { + +Device::Device(int device) : device_(device) { + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_)); + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_)); + // Validate the requirements of device. + int attr = 0; + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &attr, cudaDevAttrConcurrentManagedAccess, device_)); + if (attr != 1) { + throw std::runtime_error(fmt::format( + "Device {} does not support synchronization in managed memory.", + device_)); + } + // The cublasLt handle is used by matmul. + make_current(); + cublasLtCreate(<_); +} + +Device::~Device() { + cublasLtDestroy(lt_); +} + +void Device::make_current() { + // We need to set/get current CUDA device very frequently, cache it to reduce + // actual calls of CUDA APIs. This function assumes single-thread in host. + static int current = 0; + if (current != device_) { + CHECK_CUDA_ERROR(cudaSetDevice(device_)); + current = device_; + } +} + +CommandEncoder& Device::get_command_encoder(Stream s) { + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(s.index, *this).first; + } + return it->second; +} + +CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); + CHECK_CUDA_ERROR( + cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); +} + +CommandEncoder::CaptureContext::~CaptureContext() { + CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); + size_t num_nodes; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes)); + if (num_nodes == 1) { + cudaGraphNode_t captured_node; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes)); + CUDA_KERNEL_NODE_PARAMS params; + CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms)); + cudaGraphNode_t node; + CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms)); + enc.insert_graph_dependencies(GraphNode{node, 'K'}); + } else { + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph)); + enc.insert_graph_dependencies(GraphNode{node, 'G'}); + } + CHECK_CUDA_ERROR(cudaGraphDestroy(graph)); +} + +CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc) + : enc(enc) { + enc.in_concurrent_ = true; +} + +CommandEncoder::ConcurrentContext::~ConcurrentContext() { + enc.in_concurrent_ = false; + + // Use an empty graph node for synchronization + CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)}; + enc.empty_node_count_++; + CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0)); + + // Insert the concurrent -> empty node dependencies + for (auto& from : enc.concurrent_nodes_) { + enc.from_nodes_.push_back(from.node); + enc.to_nodes_.push_back(empty.node); + enc.graph_key_ += from.id; + enc.graph_key_ += from.node_type; + enc.graph_key_ += empty.id; + enc.graph_key_ += empty.node_type; + } + + // Insert the input -> 
concurrent node dependencies without updating output + // nodes + auto outputs = std::move(enc.active_outputs_); + enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_)); + + // Update output node to be the empty node + for (auto o : outputs) { + enc.node_map_.emplace(o, empty).first->second = empty; + } +} + +void CommandEncoder::insert_graph_dependencies(GraphNode node) { + if (node.node_type == 'G') { + graph_node_count_++; + } + node.id = std::to_string(node_count_++); + if (in_concurrent_) { + concurrent_nodes_.push_back(std::move(node)); + } else { + std::vector nodes; + nodes.push_back(std::move(node)); + insert_graph_dependencies(std::move(nodes)); + } +} + +void CommandEncoder::insert_graph_dependencies(std::vector nodes) { + std::vector deps; + { + // Dependencies must be added in the same order to produce a consistent + // topology + std::unordered_set set_deps; + for (auto d : active_deps_) { + if (auto it = node_map_.find(d); it != node_map_.end()) { + auto [_, inserted] = set_deps.insert(it->second.node); + if (inserted) { + deps.push_back(it->second); + } + } + } + } + active_deps_.clear(); + + for (auto o : active_outputs_) { + for (auto& node : nodes) { + node_map_.emplace(o, node).first->second = node; + } + } + active_outputs_.clear(); + + for (auto& from : deps) { + for (auto& to : nodes) { + from_nodes_.push_back(from.node); + to_nodes_.push_back(to.node); + graph_key_ += from.id; + graph_key_ += from.node_type; + graph_key_ += to.id; + graph_key_ += to.node_type; + } + } +} + +CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); +} + +void clear_graphs(std::unordered_map& graphs) { + for (auto& [_, graph_exec] : graphs) { + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + } + graphs.clear(); +} + +CommandEncoder::~CommandEncoder() { + clear_graphs(graph_cache_); +} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_.add_task(std::move(task)); +} + +void CommandEncoder::set_input_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); +} + +void CommandEncoder::set_output_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); + active_outputs_.push_back(id); +} + +void CommandEncoder::maybe_commit() { + if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) { + commit(); + } +} + +void CommandEncoder::add_kernel_node( + void* func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + cudaKernelNodeParams kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDim = grid_dim; + kernel_params.blockDim = block_dim; + kernel_params.kernelParams = params; + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + +void CommandEncoder::add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + CUDA_KERNEL_NODE_PARAMS kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDimX = grid_dim.x; + kernel_params.gridDimY = grid_dim.y; + kernel_params.gridDimZ = grid_dim.z; + kernel_params.blockDimX = block_dim.x; + kernel_params.blockDimY = block_dim.y; + kernel_params.blockDimZ = block_dim.z; + kernel_params.kernelParams = params; + CUgraphNode node; + CHECK_CUDA_ERROR( + cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + 
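
// Editor's sketch (illustrative, not from this patch): what add_kernel_node
// above boils down to — describe a launch in cudaKernelNodeParams, add it as
// a graph node, wire explicit dependencies, then instantiate and launch the
// graph on a stream. Error checking elided for brevity.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

int main() {
  const int n = 256;
  float* data;
  cudaMallocManaged(&data, n * sizeof(float));
  for (int i = 0; i < n; ++i) data[i] = 1.f;

  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);

  float factor = 2.f;
  int size = n;
  void* params[] = {&data, &factor, &size};  // copied when the node is added
  cudaKernelNodeParams kp = {};
  kp.func = (void*)scale;
  kp.gridDim = dim3((n + 127) / 128);
  kp.blockDim = dim3(128);
  kp.kernelParams = params;

  // Two launches of the same kernel; the second depends on the first.
  cudaGraphNode_t node_a, node_b;
  cudaGraphAddKernelNode(&node_a, graph, nullptr, 0, &kp);
  cudaGraphAddKernelNode(&node_b, graph, &node_a, 1, &kp);

  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);
  printf("data[0] = %g (expected 4, scaled twice by 2)\n", data[0]);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(data);
  return 0;
}
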
+void CommandEncoder::commit() { + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + if (node_count_ > 0) { + if (!from_nodes_.empty()) { + CHECK_CUDA_ERROR(cudaGraphAddDependencies( + graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); + } + + graph_key_ += "."; + graph_key_ += std::to_string(node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(graph_node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(empty_node_count_); + + cudaGraphExec_t& graph_exec = graph_cache_[graph_key_]; + + if (graph_exec != nullptr) { + cudaGraphExecUpdateResult update_result; +#if CUDART_VERSION >= 12000 + cudaGraphExecUpdateResultInfo info; + cudaGraphExecUpdate(graph_exec, graph_, &info); + update_result = info.result; +#else + cudaGraphNode_t error_node; + cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result); +#endif // CUDART_VERSION >= 12000 + if (update_result != cudaGraphExecUpdateSuccess) { + cudaGetLastError(); // reset error + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + graph_exec = nullptr; + } + } + if (graph_exec == nullptr) { + CHECK_CUDA_ERROR( + cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); + } + device_.make_current(); + CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); + + // TODO smarter cache policy + if (graph_cache_.size() > cuda_graph_cache_size()) { + clear_graphs(graph_cache_); + } + + // Reset state + node_count_ = 0; + graph_node_count_ = 0; + from_nodes_.clear(); + to_nodes_.clear(); + graph_key_.clear(); + node_map_.clear(); + CHECK_CUDA_ERROR(cudaGraphDestroy(graph_)); + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); + } + + // Put completion handlers in a batch. + worker_.end_batch(); + worker_.commit(stream_); +} + +void CommandEncoder::synchronize() { + cudaStreamSynchronize(stream_); + auto p = std::make_shared>(); + std::future f = p->get_future(); + add_completed_handler([p = std::move(p)]() { p->set_value(); }); + worker_.end_batch(); + commit(); + f.wait(); +} + +Device& device(mlx::core::Device device) { + static std::unordered_map devices; + auto it = devices.find(device.index); + if (it == devices.end()) { + it = devices.try_emplace(device.index, device.index).first; + } + return it->second; +} + +CommandEncoder& get_command_encoder(Stream s) { + return device(s.device).get_command_encoder(s); +} + +} // namespace cu + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h new file mode 100644 index 000000000..8ac840cbb --- /dev/null +++ b/mlx/backend/cuda/device.h @@ -0,0 +1,159 @@ +// Copyright © 2025 Apple Inc. 
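
// Editor's sketch (illustrative, not from this patch): the cache policy in
// commit() above — keep an instantiated cudaGraphExec_t, try to update it in
// place with a freshly built graph of the same topology, and only
// re-instantiate when the update is rejected. This version assumes the
// CUDA 12 cudaGraphExecUpdateResultInfo API (the patch also handles CUDA 11).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill(float* p, float v, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) p[i] = v;
}

cudaGraph_t build_graph(float* p, float v, int n) {
  cudaGraph_t g;
  cudaGraphCreate(&g, 0);
  void* params[] = {&p, &v, &n};
  cudaKernelNodeParams kp = {};
  kp.func = (void*)fill;
  kp.gridDim = dim3((n + 127) / 128);
  kp.blockDim = dim3(128);
  kp.kernelParams = params;  // parameter values are copied at node creation
  cudaGraphNode_t node;
  cudaGraphAddKernelNode(&node, g, nullptr, 0, &kp);
  return g;
}

int main() {
  const int n = 128;
  float* p;
  cudaMallocManaged(&p, n * sizeof(float));
  cudaStream_t s;
  cudaStreamCreate(&s);

  cudaGraphExec_t exec = nullptr;
  for (float v : {1.f, 2.f, 3.f}) {
    cudaGraph_t g = build_graph(p, v, n);  // same topology, new parameters
    if (exec != nullptr) {
      cudaGraphExecUpdateResultInfo info;
      cudaGraphExecUpdate(exec, g, &info);
      if (info.result != cudaGraphExecUpdateSuccess) {
        cudaGetLastError();                // clear the sticky error
        cudaGraphExecDestroy(exec);
        exec = nullptr;
      }
    }
    if (exec == nullptr) {
      cudaGraphInstantiate(&exec, g, nullptr, nullptr, 0);
    }
    cudaGraphLaunch(exec, s);
    cudaStreamSynchronize(s);
    printf("p[0] = %g\n", p[0]);
    cudaGraphDestroy(g);
  }

  cudaGraphExecDestroy(exec);
  cudaStreamDestroy(s);
  cudaFree(p);
  return 0;
}
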
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/stream.h" + +#include +#include +#include + +#include + +namespace mlx::core::cu { + +class CommandEncoder { + public: + struct CaptureContext { + CaptureContext(CommandEncoder& enc); + ~CaptureContext(); + cudaGraph_t graph; + CommandEncoder& enc; + }; + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc); + ~ConcurrentContext(); + CommandEncoder& enc; + }; + + explicit CommandEncoder(Device& d); + ~CommandEncoder(); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + CaptureContext capture_context() { + return CaptureContext{*this}; + } + ConcurrentContext concurrent_context() { + return ConcurrentContext{*this}; + } + + void set_input_array(const array& arr); + void set_output_array(const array& arr); + + template + void + add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) { + constexpr size_t num = sizeof...(Params); + void* ptrs[num]; + size_t i = 0; + ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( + std::forward(params)), + ...); + add_kernel_node((void*)func, grid_dim, block_dim, ptrs); + } + + void add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params); + + void + add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void maybe_commit(); + void commit(); + + CudaStream& stream() { + return stream_; + } + + // Wait until kernels and completion handlers are finished + void synchronize(); + + private: + struct GraphNode { + cudaGraphNode_t node; + // K = kernel + // E = empty + // G = subgraph + char node_type; + std::string id; + }; + + void insert_graph_dependencies(GraphNode node); + void insert_graph_dependencies(std::vector nodes); + + Device& device_; + CudaStream stream_; + cudaGraph_t graph_; + Worker worker_; + char node_count_{0}; + char graph_node_count_{0}; + char empty_node_count_{0}; + bool in_concurrent_{false}; + std::vector from_nodes_; + std::vector to_nodes_; + std::string graph_key_; + std::vector concurrent_nodes_; + std::vector> temporaries_; + std::unordered_map graph_cache_; + std::vector active_deps_; + std::vector active_outputs_; + std::unordered_map node_map_; +}; + +class Device { + public: + explicit Device(int device); + ~Device(); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current cuda device, required by some cuda calls. + void make_current(); + + CommandEncoder& get_command_encoder(Stream s); + + int cuda_device() const { + return device_; + } + int compute_capability_major() const { + return compute_capability_major_; + } + int compute_capability_minor() const { + return compute_capability_minor_; + } + cublasLtHandle_t lt_handle() const { + return lt_; + } + + private: + int device_; + int compute_capability_major_; + int compute_capability_minor_; + cublasLtHandle_t lt_; + std::unordered_map encoders_; +}; + +Device& device(mlx::core::Device device); +CommandEncoder& get_command_encoder(Stream s); + +// Return an execution policy that does not sync for result. +// Note that not all thrust APIs support async policy, confirm before using. +inline auto thrust_policy(cudaStream_t stream) { + // TODO: Connect thrust's custom allocator with mlx's allocator. 
+ return thrust::cuda::par_nosync.on(stream); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/arange.cuh b/mlx/backend/cuda/device/arange.cuh new file mode 100644 index 000000000..53c261e34 --- /dev/null +++ b/mlx/backend/cuda/device/arange.cuh @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::cu { + +template +struct Arange { + const T start; + const T step; + + __device__ T operator()(uint32_t i) const { + return start + i * step; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh new file mode 100644 index 000000000..5df246c0e --- /dev/null +++ b/mlx/backend/cuda/device/atomic_ops.cuh @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/complex.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" + +#include + +namespace mlx::core::cu { + +template +inline __device__ void atomic_add(T* out, T val) { + cuda::atomic_ref ref(*out); + ref += val; +} + +template +inline __device__ void atomic_prod(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old * val)) { + } +} + +template +inline __device__ void atomic_max(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_max(val); +} + +template +inline __device__ void atomic_min(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_min(val); +} + +// Somehow cuda::atomic_ref does not provide atomic add for following types. +template +inline __device__ void atomic_add_general(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old + val)) { + } +} + +inline __device__ void atomic_add(__half* out, __half val) { + atomicAdd(out, val); +} + +inline __device__ void atomic_add(complex64_t* out, complex64_t val) { +#if __CUDA_ARCH__ < 900 + atomic_add_general(out, val); +#else + atomicAdd(out, val); +#endif +} + +inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { +#if __CUDA_ARCH__ < 800 + atomic_add_general(out, val); +#else + atomicAdd(out, val); +#endif +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh new file mode 100644 index 000000000..575aced14 --- /dev/null +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -0,0 +1,293 @@ +// Copyright © 2025 Apple Inc. 
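
// Editor's sketch (illustrative, not from this patch): the compare-exchange
// loop that atomic_prod / atomic_add_general above rely on, shown standalone
// for a product reduction. compare_exchange_strong refreshes `old` with the
// current value whenever another thread won the race, so the loop simply
// retries. Assumes a libcu++ new enough to provide cuda::atomic_ref
// (CUDA 11.3+) and an sm_60+ device for device-scope atomics.
#include <cstdio>
#include <cuda/atomic>
#include <cuda_runtime.h>

__global__ void atomic_prod_kernel(float* out, const float* vals, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    cuda::atomic_ref<float, cuda::thread_scope_device> ref(*out);
    float old = ref.load();
    while (!ref.compare_exchange_strong(old, old * vals[i])) {
      // retry with the refreshed `old`
    }
  }
}

int main() {
  float *out, *vals;
  cudaMallocManaged(&out, sizeof(float));
  cudaMallocManaged(&vals, 8 * sizeof(float));
  *out = 1.f;
  for (int i = 0; i < 8; ++i) vals[i] = 2.f;
  atomic_prod_kernel<<<1, 8>>>(out, vals, 8);
  cudaDeviceSynchronize();
  printf("product = %g (expected 256)\n", *out);
  cudaFree(out);
  cudaFree(vals);
  return 0;
}
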
+ +#include "mlx/backend/cuda/device/unary_ops.cuh" + +#include + +namespace mlx::core::cu { + +struct Add { + template + __device__ T operator()(T x, T y) { + return x + y; + } +}; + +struct FloorDivide { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return x / y; + } else { + return truncf(x / y); + } + } +}; + +struct Divide { + template + __device__ T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + if constexpr (cuda::std::is_signed_v) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } else { + return x % y; + } + } else if constexpr (is_complex_v) { + return x % y; + } else { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r = r + y; + } + return r; + } + } +}; + +struct Equal { + template + __device__ bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (is_complex_v) { + return x == y || + (isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) && + isnan(y.imag())) || + (x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) || + (isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag()); + } else { + return x == y || (isnan(x) && isnan(y)); + } + } +}; + +struct Greater { + template + __device__ bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + __device__ bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + __device__ bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + __device__ bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + __device__ T operator()(T x, T y) { + if constexpr (is_complex_v) { + if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) || + isnan(y.imag())) { + return { + cuda::std::numeric_limits::quiet_NaN(), + cuda::std::numeric_limits::quiet_NaN()}; + } + auto max = x.real() > y.real() ? x : y; + auto min = x.real() < y.real() ? x : y; + auto min_real = min.real(); + auto max_real = max.real(); + if (!isfinite(min_real) && (min_real == max_real)) { + if (min_real < 0) { + return min; + } else { + return Log{}(Exp{}(min) + Exp{}(max)); + } + } else { + return Log1p{}(Exp{}(min - max)) + max; + } + } else { + if (isnan(x) || isnan(y)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + T maxval = max(x, y); + T minval = min(x, y); + return (minval == -cuda::std::numeric_limits::infinity() || + maxval == cuda::std::numeric_limits::infinity()) + ? maxval + : T(float(maxval) + log1p(expf(minval - maxval))); + } + }; +}; + +struct Maximum { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return max(x, y); + } else if constexpr (is_complex_v) { + if (isnan(x.real()) || isnan(x.imag())) { + return x; + } + return x > y ? x : y; + } else { + if (isnan(x)) { + return x; + } + return x > y ? x : y; + } + } +}; + +struct Minimum { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return min(x, y); + } else if constexpr (is_complex_v) { + if (isnan(x.real()) || isnan(x.imag())) { + return x; + } + return x < y ? x : y; + } else { + if (isnan(x)) { + return x; + } + return x < y ? 
x : y; + } + } +}; + +struct Multiply { + template + __device__ T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (is_complex_v) { + return x.real() != y.real() || x.imag() != y.imag(); + } else { + return x != y; + } + } +}; + +struct Power { + template + __device__ T operator()(T base, T exp) { + if constexpr (cuda::std::is_integral_v) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } else if constexpr (is_complex_v) { + return pow(base, exp); + } else { + return powf(base, exp); + } + } +}; + +struct Subtract { + template + __device__ T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + __device__ T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + __device__ T operator()(T x, T y) { + return x || y; + }; +}; + +struct BitwiseAnd { + template + __device__ T operator()(T x, T y) { + return x & y; + }; +}; + +struct BitwiseOr { + template + __device__ T operator()(T x, T y) { + return x | y; + }; +}; + +struct BitwiseXor { + template + __device__ T operator()(T x, T y) { + return x ^ y; + }; +}; + +struct LeftShift { + template + __device__ T operator()(T x, T y) { + return x << y; + }; +}; + +struct RightShift { + template + __device__ T operator()(T x, T y) { + return x >> y; + }; +}; + +struct ArcTan2 { + template + __device__ T operator()(T y, T x) { + return atan2f(y, x); + } +}; + +struct DivMod { + template + __device__ cuda::std::array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh new file mode 100644 index 000000000..e10fde6dc --- /dev/null +++ b/mlx/backend/cuda/device/cast_op.cuh @@ -0,0 +1,130 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/complex.cuh" + +#include +#include +#include + +namespace mlx::core::cu { + +// An op that does static_cast, with custom conversions for some types. +template +struct CastOp { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ DstT operator()(SrcT x) { + return static_cast(x); + } +}; + +// Castings between complex and boolean. +template +struct CastOp, bool> { + static constexpr bool is_castable = true; + + __device__ bool operator()(complex_t x) { + return x.real() != 0 && x.imag() != 0; + } +}; + +template +struct CastOp> { + static constexpr bool is_castable = true; + + __device__ complex_t operator()(bool x) { + return x ? complex_t{1, 1} : complex_t{0, 0}; + } +}; + +// Converting a complex number to real number discards the imaginary part. +template +struct CastOp, DstT, cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ DstT operator()(complex_t x) { + static_assert(!is_complex_v); + return static_cast(x.real()); + } +}; + +// Allow converting a real number to complex number. +template +struct CastOp, cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ complex_t operator()(SrcT x) { + static_assert(!is_complex_v); + return complex_t{static_cast(x), 0}; + } +}; + +// Do nothing when no casting is needed. 
+template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = true; + + __device__ SrcT operator()(SrcT x) { + return x; + } +}; + +// In CUDA 11 the half types do not define conversions between some types, +// provide fallbacks here. +#if CUDART_VERSION < 12000 +template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t< + !cuda::std::is_convertible_v && !is_complex_v && + (cuda::std::is_same_v || + cuda::std::is_same_v)>> { + static constexpr bool is_castable = true; + + __device__ DstT operator()(SrcT x) { + return DstT(static_cast(x)); + } +}; + +template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t< + !cuda::std::is_convertible_v && !is_complex_v && + !cuda::std::is_same_v && + !cuda::std::is_same_v && + (cuda::std::is_same_v || + cuda::std::is_same_v)>> { + static constexpr bool is_castable = true; + + __device__ DstT operator()(SrcT x) { + return DstT(static_cast(x)); + } +}; +#endif // CUDART_VERSION < 12000 + +// Helper to deduce the SrcT. +template +inline __host__ __device__ auto cast_to(SrcT x) { + return CastOp{}(x); +} + +// Return an iterator that cast the value to DstT using CastOp. +template +inline __host__ __device__ auto make_cast_iterator(Iterator it) { + using SrcT = typename cuda::std::iterator_traits::value_type; + if constexpr (std::is_same_v) { + return it; + } else { + return thrust::make_transform_iterator(it, CastOp{}); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/complex.cuh b/mlx/backend/cuda/device/complex.cuh new file mode 100644 index 000000000..03a7bff83 --- /dev/null +++ b/mlx/backend/cuda/device/complex.cuh @@ -0,0 +1,60 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +// Make multiplication and division faster. +#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS + +#include +#include + +namespace mlx::core::cu { + +// TODO: Consider using a faster implementation as cuda::std::complex has to +// conform to C++ standard. +template +using complex_t = cuda::std::complex; + +using complex64_t = complex_t; +using complex128_t = complex_t; + +template +struct is_complex : cuda::std::false_type {}; + +template +struct is_complex> : cuda::std::true_type {}; + +template +inline constexpr bool is_complex_v = is_complex::value; + +// cuda::std::complex is missing some operators. +template +inline __host__ __device__ complex_t operator%( + complex_t a, + complex_t b) { + T r = a.real() - floor(a.real() / b.real()) * b.real(); + T i = a.imag() - floor(a.imag() / b.imag()) * b.imag(); + return complex_t{r, i}; +} + +template +inline __host__ __device__ bool operator>(complex_t a, complex_t b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); +} + +template +inline __host__ __device__ bool operator<(complex_t a, complex_t b) { + return operator>(b, a); +} + +template +inline __host__ __device__ bool operator<=(complex_t a, complex_t b) { + return !(a > b); +} + +template +inline __host__ __device__ bool operator>=(complex_t a, complex_t b) { + return !(a < b); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/config.h b/mlx/backend/cuda/device/config.h new file mode 100644 index 000000000..5a3402905 --- /dev/null +++ b/mlx/backend/cuda/device/config.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +// This file is used by both CUDA kernel code and host-only C++ code. + +#pragma once + +// The maximum dimensions of shape/strides passed as kernel parameters. 
+#define MAX_NDIM 10 + +// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in +// warpSize variable exists, using it would prevent compile-time optimizations. +#define WARP_SIZE 32 diff --git a/mlx/backend/cuda/device/fp16_math.cuh b/mlx/backend/cuda/device/fp16_math.cuh new file mode 100644 index 000000000..f6fa17bb9 --- /dev/null +++ b/mlx/backend/cuda/device/fp16_math.cuh @@ -0,0 +1,194 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// Unary ops for half types. +/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 +#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x); \ + } else { \ + return ::NAME(x); \ + } \ + } +#else +#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x); \ + } else if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x); \ + } else { \ + return ::NAME(x); \ + } \ + } +#endif + +#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME) \ + template \ + __forceinline__ __device__ auto NAME(T x) { \ + if constexpr (cuda::std::is_same_v) { \ + return ::NAME(__half2float(x)); \ + } else if constexpr (cuda::std::is_same_v) { \ + return ::NAME(__bfloat162float(x)); \ + } else { \ + return ::NAME(x); \ + } \ + } + +MLX_DEFINE_UNARY_OP(abs, __habs) +MLX_DEFINE_UNARY_OP(ceil, hceil) +MLX_DEFINE_UNARY_OP(cos, hcos) +MLX_DEFINE_UNARY_OP(exp, hexp) +MLX_DEFINE_UNARY_OP(floor, hfloor) +MLX_DEFINE_UNARY_OP(isnan, __hisnan) +MLX_DEFINE_UNARY_OP(log, hlog) +MLX_DEFINE_UNARY_OP(log2, hlog2) +MLX_DEFINE_UNARY_OP(log10, hlog10) +MLX_DEFINE_UNARY_OP(rint, hrint) +MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt) +MLX_DEFINE_UNARY_OP(sin, hsin) +MLX_DEFINE_UNARY_OP(sqrt, hsqrt) +MLX_DEFINE_UNARY_OP_FALLBCK(acos) +MLX_DEFINE_UNARY_OP_FALLBCK(acosh) +MLX_DEFINE_UNARY_OP_FALLBCK(asin) +MLX_DEFINE_UNARY_OP_FALLBCK(asinh) +MLX_DEFINE_UNARY_OP_FALLBCK(atan) +MLX_DEFINE_UNARY_OP_FALLBCK(atanh) +MLX_DEFINE_UNARY_OP_FALLBCK(cosh) +MLX_DEFINE_UNARY_OP_FALLBCK(log1p) +MLX_DEFINE_UNARY_OP_FALLBCK(sinh) +MLX_DEFINE_UNARY_OP_FALLBCK(tan) +#if __CUDA_ARCH__ >= 1280 +MLX_DEFINE_UNARY_OP(tanh, htanh) +#else +MLX_DEFINE_UNARY_OP_FALLBCK(tanh) +#endif + +#undef MLX_DEFINE_UNARY_OP +#undef MLX_DEFINE_UNARY_OP_FALLBCK + +/////////////////////////////////////////////////////////////////////////////// +// Binary ops for half types. 
+/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 +#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x, T y) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else { \ + return ::NAME(x, y); \ + } \ + } +#else +#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x, T y) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else { \ + return ::NAME(x, y); \ + } \ + } +#endif + +MLX_DEFINE_BINARY_OP(max, __hmax) +MLX_DEFINE_BINARY_OP(min, __hmin) + +#undef MLX_DEFINE_BINARY_OP + +template +__forceinline__ __device__ T fmod(T x, T y) { + if constexpr (cuda::std::is_same_v) { + return __float2half(::fmod(__half2float(x), __half2float(y))); +#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800 + } else if constexpr (cuda::std::is_same_v) { + return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y))); +#endif + } else { + return ::fmod(x, y); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Additional C++ operator overrides between half types and native types. +/////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool is_integral_except = + cuda::std::is_integral_v && !cuda::std::is_same_v; + +template +constexpr bool is_arithmetic_except = + cuda::std::is_arithmetic_v && !cuda::std::is_same_v; + +#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(HALF x, T y) { \ + return FLOAT2HALF(HALF2FLOAT(x) OP static_cast(y)); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(T x, HALF y) { \ + return FLOAT2HALF(static_cast(x) OP HALF2FLOAT(y)); \ + } + +#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(HALF x, T y) { \ + return HALF2FLOAT(x) OP static_cast(y); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(T x, HALF y) { \ + return static_cast(y) OP HALF2FLOAT(x); \ + } + +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /) +MLX_DEFINE_HALF_CMP(__half, __half2float, <) +MLX_DEFINE_HALF_CMP(__half, __half2float, >) +MLX_DEFINE_HALF_CMP(__half, __half2float, <=) +MLX_DEFINE_HALF_CMP(__half, __half2float, >=) +MLX_DEFINE_HALF_CMP(__half, __half2float, ==) +MLX_DEFINE_HALF_CMP(__half, __half2float, !=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=) 
+MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=) + +#undef MLX_DEFINE_HALF_OP +#undef MLX_DEFINE_HALF_CMP + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather.cuh b/mlx/backend/cuda/device/gather.cuh new file mode 100644 index 000000000..7dbd84ac3 --- /dev/null +++ b/mlx/backend/cuda/device/gather.cuh @@ -0,0 +1,53 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template +__global__ void gather( + const T* src, + T* out, + LocT size, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim, + const __grid_constant__ Shape slice_sizes, + uint32_t slice_size, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT out_idx = cg::this_grid().thread_rank(); + if (out_idx >= size) { + return; + } + + LocT src_elem = out_idx % slice_size; + LocT idx_elem = out_idx / slice_size; + + LocT src_loc = + elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]); + src_loc += idx_val * src_strides[axis]; + } + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather_axis.cuh b/mlx/backend/cuda/device/gather_axis.cuh new file mode 100644 index 000000000..f863b2d95 --- /dev/null +++ b/mlx/backend/cuda/device/gather_axis.cuh @@ -0,0 +1,65 @@ +// Copyright © 2025 Apple Inc. 
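
// Editor's sketch (illustrative, not from this patch): the gather kernel
// above combines per-index-array locations with absolute_index so negative
// indices wrap Python-style. Reduced here to one index array over a 1-D
// source; names are hypothetical.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void gather1d(
    const float* src,
    const int32_t* indices,
    float* out,
    int n_indices,
    int32_t src_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_indices) {
    int32_t idx = indices[i];
    if (idx < 0) idx += src_size;  // wrap negative indices into [0, src_size)
    out[i] = src[idx];
  }
}

int main() {
  float* src; int32_t* idx; float* out;
  cudaMallocManaged(&src, 5 * sizeof(float));
  cudaMallocManaged(&idx, 3 * sizeof(int32_t));
  cudaMallocManaged(&out, 3 * sizeof(float));
  for (int i = 0; i < 5; ++i) src[i] = float(10 * i);
  idx[0] = 0; idx[1] = -1; idx[2] = 3;  // -1 means the last element
  gather1d<<<1, 3>>>(src, idx, out, 3, 5);
  cudaDeviceSynchronize();
  printf("%g %g %g (expected 0 40 30)\n", out[0], out[1], out[2]);
  cudaFree(src); cudaFree(idx); cudaFree(out);
  return 0;
}
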
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + int NDIM, + bool SrcC, + bool IdxC, + typename LocT> +__global__ void gather_axis( + const T* src, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array src_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t src_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT src_loc = idx_val * src_stride_axis; + if constexpr (SrcC) { + src_loc += elem_idx * axis_size + x; + } else { + src_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), src_strides.data()); + } + + LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x; + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/indexing.cuh b/mlx/backend/cuda/device/indexing.cuh new file mode 100644 index 000000000..31cba1a90 --- /dev/null +++ b/mlx/backend/cuda/device/indexing.cuh @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +namespace mlx::core::cu { + +// Convert an absolute index to positions in a 3d grid, assuming the index is +// calculated with: +// index = x * dim1 * dim2 + y * dim2 + z +template +inline __host__ __device__ cuda::std::tuple +index_to_dims(T index, T dim1, T dim2) { + T x = index / (dim1 * dim2); + T y = (index % (dim1 * dim2)) / dim2; + T z = index % dim2; + return cuda::std::make_tuple(x, y, z); +} + +// Get absolute index from possible negative index. +template +inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) { + if constexpr (cuda::std::is_unsigned_v) { + return idx; + } else { + return static_cast(idx < 0 ? idx + size : idx); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh new file mode 100644 index 000000000..b2f640350 --- /dev/null +++ b/mlx/backend/cuda/device/scatter.cuh @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. 
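// A quick host-side illustration of the helpers in indexing.cuh above, which
// the scatter kernels below reuse as well: index_to_dims inverts the
// flattening index = x * dim1 * dim2 + y * dim2 + z, and absolute_index wraps
// Python-style negative indices. The numbers are example values only.

#include <cassert>
#include <cstdint>

inline void indexing_helpers_example() {
  // Flattening (x=1, y=2, z=3) with dim1=4, dim2=5 gives 1*20 + 2*5 + 3 = 33;
  // the decomposition below recovers the original coordinates.
  int64_t index = 33, dim1 = 4, dim2 = 5;
  int64_t x = index / (dim1 * dim2);          // 1
  int64_t y = (index % (dim1 * dim2)) / dim2; // 2
  int64_t z = index % dim2;                   // 3
  assert(x == 1 && y == 2 && z == 3);

  // absolute_index(-1, 5) resolves to 4; non-negative indices pass through,
  // and unsigned index types skip the check entirely.
  int32_t size = 5;
  int32_t idx = -1;
  int64_t wrapped = idx < 0 ? idx + size : idx;
  assert(wrapped == 4);
  (void)x; (void)y; (void)z; (void)wrapped;
}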
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + int IDX_NDIM, + typename LocT> +__global__ void scatter( + const T* upd, + T* out, + LocT size, + const __grid_constant__ Shape upd_shape, + const __grid_constant__ Strides upd_strides, + int32_t upd_ndim, + LocT upd_post_idx_size, + const __grid_constant__ Shape out_shape, + const __grid_constant__ Strides out_strides, + int32_t out_ndim, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT upd_idx = cg::this_grid().thread_rank(); + if (upd_idx >= size) { + return; + } + + LocT out_elem = upd_idx % upd_post_idx_size; + LocT idx_elem = upd_idx / upd_post_idx_size; + + LocT out_idx = elem_to_loc( + out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]); + out_idx += idx_val * out_strides[axis]; + } + + LocT upd_loc = elem_to_loc( + out_elem + idx_elem * upd_post_idx_size, + upd_shape.data(), + upd_strides.data(), + upd_ndim); + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_axis.cuh b/mlx/backend/cuda/device/scatter_axis.cuh new file mode 100644 index 000000000..1f30f2ebd --- /dev/null +++ b/mlx/backend/cuda/device/scatter_axis.cuh @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. 
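// The scatter kernel above (and scatter_axis below) computes a destination
// location per update element and then combines colliding writes through the
// Op functor, so duplicate indices are well defined: ScatterSum/Prod/Max/Min
// use atomics, ScatterAssign is a plain store. A 1-D host analogue of the
// scatter-sum case, with illustrative names:

#include <cstddef>
#include <vector>

inline void ref_scatter_sum_1d(
    const std::vector<float>& upd,
    const std::vector<int>& indices, // same length as upd
    std::vector<float>& out) {
  for (size_t i = 0; i < upd.size(); ++i) {
    int idx = indices[i];
    if (idx < 0) {
      idx += static_cast<int>(out.size()); // negative indices wrap, as on device
    }
    out[idx] += upd[i]; // ScatterSum; ScatterAssign would overwrite instead
  }
}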
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NDIM, + bool UpdC, + bool IdxC, + typename LocT> +__global__ void scatter_axis( + const T* upd, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array upd_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t upd_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT upd_loc = y * upd_stride_axis; + if constexpr (UpdC) { + upd_loc += elem_idx * idx_size_axis + x; + } else { + upd_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), upd_strides.data()); + } + + LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x; + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_ops.cuh b/mlx/backend/cuda/device/scatter_ops.cuh new file mode 100644 index 000000000..d88f896ad --- /dev/null +++ b/mlx/backend/cuda/device/scatter_ops.cuh @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/atomic_ops.cuh" + +namespace mlx::core::cu { + +struct ScatterAssign { + template + __device__ void operator()(T* out, T val) const { + *out = val; + } +}; + +struct ScatterSum { + template + __device__ void operator()(T* out, T val) const { + atomic_add(out, val); + } +}; + +struct ScatterProd { + template + __device__ void operator()(T* out, T val) const { + atomic_prod(out, val); + } +}; + +struct ScatterMax { + template + __device__ void operator()(T* out, T val) const { + atomic_max(out, val); + } +}; + +struct ScatterMin { + template + __device__ void operator()(T* out, T val) const { + atomic_min(out, val); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/ternary_ops.cuh b/mlx/backend/cuda/device/ternary_ops.cuh new file mode 100644 index 000000000..441845471 --- /dev/null +++ b/mlx/backend/cuda/device/ternary_ops.cuh @@ -0,0 +1,13 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +namespace mlx::core::cu { + +struct Select { + template + __device__ T operator()(bool condition, T x, T y) { + return condition ? x : y; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh new file mode 100644 index 000000000..aebed1e4d --- /dev/null +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -0,0 +1,337 @@ +// Copyright © 2025 Apple Inc. 
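// The functors in scatter_ops.cuh and ternary_ops.cuh above, like the unary
// ops below, are stateless structs with a templated operator() so one kernel
// template can be instantiated per operation. A minimal sketch of how such an
// Op plugs into a kernel; the kernel itself is illustrative, not part of this
// patch:

template <typename Op, typename T>
__global__ void apply_scatter_op(T* out, const T* upd, const int* idx, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // For ScatterSum this lowers to an atomic add; for ScatterAssign it is a
    // plain store, so racing duplicates keep "some writer wins" semantics.
    Op{}(out + idx[i], upd[i]);
  }
}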
+ +#pragma once + +#include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +struct Abs { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_unsigned_v) { + return x; + } else { + return abs(x); + } + } +}; + +struct ArcCos { + template + __device__ T operator()(T x) { + return acos(x); + } +}; + +struct ArcCosh { + template + __device__ T operator()(T x) { + return acosh(x); + } +}; + +struct ArcSin { + template + __device__ T operator()(T x) { + return asin(x); + } +}; + +struct ArcSinh { + template + __device__ T operator()(T x) { + return asinh(x); + } +}; + +struct ArcTan { + template + __device__ T operator()(T x) { + return atan(x); + } +}; + +struct ArcTanh { + template + __device__ T operator()(T x) { + return atanh(x); + } +}; + +struct BitwiseInvert { + template + __device__ T operator()(T x) { + return ~x; + } +}; + +struct Ceil { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_integral_v) { + return x; + } else if constexpr (is_complex_v) { + return T{ceil(x.real()), ceil(x.imag())}; + } else { + return ceil(x); + } + } +}; + +struct Conjugate { + template + __device__ complex_t operator()(complex_t x) { + return conj(x); + } +}; + +struct Cos { + template + __device__ T operator()(T x) { + return cos(x); + } +}; + +struct Cosh { + template + __device__ T operator()(T x) { + return cosh(x); + } +}; + +struct Erf { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return erf(__half2float(x)); + } else if constexpr (cuda::std::is_same_v) { + return erf(__bfloat162float(x)); + } else { + return erf(x); + } + } +}; + +struct ErfInv { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return erfinv(__half2float(x)); + } else if constexpr (cuda::std::is_same_v) { + return erfinv(__bfloat162float(x)); + } else { + return erfinv(x); + } + } +}; + +struct Exp { + template + __device__ T operator()(T x) { + return exp(x); + } +}; + +struct Expm1 { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return expm1(__half2float(x)); + } else if constexpr (cuda::std::is_same_v) { + return expm1(__bfloat162float(x)); + } else { + return expm1(x); + } + } +}; + +struct Floor { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_integral_v) { + return x; + } else if constexpr (is_complex_v) { + return T{floor(x.real()), floor(x.imag())}; + } else { + return floor(x); + } + } +}; + +struct Imag { + template + __device__ auto operator()(complex_t x) { + return x.imag(); + } +}; + +struct Log { + template + __device__ T operator()(T x) { + return log(x); + } +}; + +struct Log2 { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + auto y = Log{}(x); + return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F}; + } else { + return log2(x); + } + } +}; + +struct Log10 { + template + __device__ T operator()(T x) { + return log10(x); + } +}; + +struct Log1p { + template + __device__ T operator()(T z) { + if constexpr (is_complex_v) { + float x = z.real(); + float y = z.imag(); + float zabs = Abs{}(z).real(); + float theta = atan2f(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1pf(r), theta}; + } else { + float z0 = hypotf(x + 1, y); + return {logf(z0), theta}; + } + } else { + return log1p(z); + } + } +}; + +struct LogicalNot { 
+ __device__ bool operator()(bool x) { + return !x; + } +}; + +struct Negative { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return T{0, 0} - x; + } else { + return -x; + } + } +}; + +struct Real { + template + __device__ auto operator()(complex_t x) { + return x.real(); + } +}; + +struct Round { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return {rint(x.real()), rint(x.imag())}; + } else { + return rint(x); + } + } +}; + +struct Sigmoid { + template + __device__ T operator()(T x) { + T y = 1 / (1 + exp(-abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_unsigned_v) { + return x != 0; + } else if constexpr (is_complex_v) { + if (x.real() == 0 && x.imag() == 0) { + return x; + } else { + return x / Abs()(x); + } + } else if constexpr (cuda::std::is_same_v) { + return static_cast((x > T(0.f)) - (x < T(0.f))); + } else { + return (x > T(0)) - (x < T(0)); + } + } +}; + +struct Sin { + template + __device__ T operator()(T x) { + return sin(x); + } +}; + +struct Sinh { + template + __device__ T operator()(T x) { + return sinh(x); + } +}; + +struct Square { + template + __device__ T operator()(T x) { + return x * x; + } +}; + +struct Sqrt { + template + __device__ T operator()(T x) { + return sqrt(x); + } +}; + +struct Rsqrt { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return 1.0f / Sqrt{}(x); + } else { + return rsqrt(x); + } + } +}; + +struct Tan { + template + __device__ T operator()(T x) { + return tan(x); + } +}; + +struct Tanh { + template + __device__ T operator()(T x) { + return tanh(x); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh new file mode 100644 index 000000000..73bc7ff63 --- /dev/null +++ b/mlx/backend/cuda/device/utils.cuh @@ -0,0 +1,362 @@ +// Copyright © 2025 Apple Inc. + +// This file must not include any host-only code, utilies that work under both +// host and device can be put here. +// +// See more about the requirements at: +// https://docs.nvidia.com/cuda/nvrtc/#language + +#pragma once + +#include "mlx/backend/cuda/device/complex.cuh" +#include "mlx/backend/cuda/device/config.h" + +#include +#include +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// CUDA kernel utils +/////////////////////////////////////////////////////////////////////////////// + +// To pass shape/strides to kernels via constant memory, their size must be +// known at compile time. +using Shape = cuda::std::array; +using Strides = cuda::std::array; + +// Vectorized load/store. 
+template +struct alignas(sizeof(T) * N) AlignedVector { + T val[N]; +}; + +template +inline __device__ AlignedVector load_vector( + const T* ptr, + uint32_t offset) { + auto* from = reinterpret_cast*>(ptr); + return from[offset]; +} + +template +inline __device__ void +store_vector(T* ptr, uint32_t offset, const AlignedVector& vec) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; +} + +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T min() { + return cuda::std::numeric_limits::min(); + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { + return cuda::std::numeric_limits::min(); + } +}; + +template +struct Limits< + T, + cuda::std::enable_if_t< + cuda::std::is_same_v || cuda::std::is_same_v>> { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T min() { + return -cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { + return cuda::std::numeric_limits::lowest(); + } +}; + +// CUDA 11 does not have host side arithmatic operators for half types. +template +struct Limits< + T, + cuda::std::enable_if_t< + cuda::std::is_same_v || + cuda::std::is_same_v>> { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T min() { +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + return -cuda::std::numeric_limits::infinity(); +#else + return -cuda::std::numeric_limits::infinity(); +#endif + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + return cuda::std::numeric_limits::lowest(); +#else + return cuda::std::numeric_limits::lowest(); +#endif + } +}; + +template <> +struct Limits { + static constexpr __host__ __device__ bool max() { + return true; + } + static constexpr __host__ __device__ bool min() { + return false; + } +}; + +template +struct Limits> { + static constexpr __host__ __device__ complex_t max() { + return {Limits::max(), Limits::max()}; + } + static constexpr __host__ __device__ complex_t min() { + return {Limits::min(), Limits::min()}; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils +/////////////////////////////////////////////////////////////////////////////// + +template +inline __host__ __device__ IdxT +elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +// Optimize when the ndim is known at compile time. 
+template +inline __host__ __device__ IdxT +elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) { + IdxT loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides) { + IdxT a_loc = 0; + IdxT b_loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc); +} + +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + const int64_t* c_strides) { + IdxT a_loc = 0; + IdxT b_loc = 0; + IdxT c_loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + c_loc += dim_idx * IdxT(c_strides[i]); + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc, c_loc); +} + +// Optimized version when ndim is larger than 4. +template +inline __host__ __device__ IdxT +elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + int ndim) { + IdxT a_loc = 0; + IdxT b_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc); +} + +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + const int64_t* c_strides, + int ndim) { + IdxT a_loc = 0; + IdxT b_loc = 0; + IdxT c_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + c_loc += dim_idx * IdxT(c_strides[i]); + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc, c_loc); +} + +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + + __device__ void next(const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index++; + offset += OffsetT(strides[dim - 1]); + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index += n; + offset += n * OffsetT(strides[dim - 1]); + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + 
extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } + index = 0; + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, true, OffsetT> { + int dim; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim) {} + + __device__ void next(const int* shape, const int64_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, false, OffsetT> { + OffsetT offset{0}; + + __device__ LoopedElemToLoc(int) {} + + __device__ void next(const int*, const int64_t* strides) { + offset += OffsetT(strides[0]); + } + + __device__ void next(int n, const int*, const int64_t* strides) { + offset += n * OffsetT(strides[0]); + } + + __device__ OffsetT location() { + return offset; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp new file mode 100644 index 000000000..40beb12d2 --- /dev/null +++ b/mlx/backend/cuda/eval.cpp @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream s) { + // Force initalization of cuda, so cuda runtime get destroyed at last. + cudaFree(nullptr); + // Ensure the static stream objects get created. + cu::get_command_encoder(s); + // The main thread is safe to free buffers. + cu::allocator().register_this_thread(); +} + +void eval(array& arr) { + nvtx3::scoped_range r("gpu::eval"); + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + + auto& encoder = cu::get_command_encoder(arr.primitive().stream()); + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + encoder.maybe_commit(); +} + +void finalize(Stream s) { + nvtx3::scoped_range r("gpu::finalize"); + cu::get_command_encoder(s).commit(); +} + +void synchronize(Stream s) { + nvtx3::scoped_range r("gpu::synchronize"); + cu::get_command_encoder(s).synchronize(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu new file mode 100644 index 000000000..f51d2f2e3 --- /dev/null +++ b/mlx/backend/cuda/event.cu @@ -0,0 +1,265 @@ +// Copyright © 2024 Apple Inc. 
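// Two synchronization primitives are implemented below: CudaEvent wraps a
// native cudaEvent_t (fast, but it can only make the host or another GPU
// stream wait on GPU work), while SharedEvent is a cuda::atomic in managed
// memory that host and device can both wait on and signal with arbitrary
// 64-bit values. A sketch of the native-event pattern CudaEvent builds on,
// with error handling omitted:

#include <cuda_runtime.h>

inline void native_event_sketch(cudaStream_t producer, cudaStream_t consumer) {
  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming | cudaEventBlockingSync);
  cudaEventRecord(ev, producer);     // mark a point in the producer stream
  cudaStreamWaitEvent(consumer, ev); // consumer stream waits on the GPU side
  cudaEventSynchronize(ev);          // or block the calling host thread
  cudaEventDestroy(ev);
}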
+ +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/event.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace cu { + +/////////////////////////////////////////////////////////////////////////////// +// CudaEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +// Cuda event managed with RAII. +class CudaEventHandle { + public: + CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventCreateWithFlags( + &event_, cudaEventDisableTiming | cudaEventBlockingSync)); + } + + ~CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventDestroy(event_)); + } + + CudaEventHandle(const CudaEventHandle&) = delete; + CudaEventHandle& operator=(const CudaEventHandle&) = delete; + + operator cudaEvent_t() const { + return event_; + } + + private: + cudaEvent_t event_; +}; + +CudaEvent::CudaEvent() : event_(std::make_shared()) {} + +void CudaEvent::wait() { + nvtx3::scoped_range r("cu::CudaEvent::wait"); + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaEventSynchronize(*event_); +} + +void CudaEvent::wait(cudaStream_t stream) { + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaStreamWaitEvent(stream, *event_); +} + +void CudaEvent::wait(Stream s) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this]() mutable { wait(); }); + } else { + auto& enc = cu::get_command_encoder(s); + enc.commit(); + wait(enc.stream()); + } +} + +void CudaEvent::record(cudaStream_t stream) { + cudaEventRecord(*event_, stream); + recorded_ = true; +} + +void CudaEvent::record(Stream s) { + if (s.device == mlx::core::Device::cpu) { + throw std::runtime_error("CudaEvent can not wait on cpu stream."); + } else { + auto& enc = cu::get_command_encoder(s); + enc.commit(); + record(enc.stream()); + } +} + +bool CudaEvent::completed() const { + return cudaEventQuery(*event_) == cudaSuccess; +} + +/////////////////////////////////////////////////////////////////////////////// +// SharedEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { + uint64_t current; + while ((current = ac->load()) < value) { + ac->wait(current); + } +} + +__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) { + ac->store(value); + ac->notify_all(); +} + +__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_wait(ac, value); +} + +__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_signal(ac, value); +} + +SharedEvent::SharedEvent() { + // Allocate cuda::atomic on managed memory. 
+ Atomic* ac; + CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic))); + new (ac) Atomic(0); + ac_ = std::shared_ptr(ac, [](Atomic* ptr) { + ptr->~Atomic(); + allocator().cuda_free(ptr); + }); +} + +void SharedEvent::wait(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait"); + event_wait(ac_.get(), value); +} + +void SharedEvent::wait(cudaStream_t stream, uint64_t value) { + event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::wait(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.commit(); + wait(encoder.stream(), value); + encoder.add_completed_handler([ac = ac_]() {}); + } +} + +void SharedEvent::signal(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal"); + event_signal(ac_.get(), value); +} + +void SharedEvent::signal(cudaStream_t stream, uint64_t value) { + event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::signal(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); + if (s.device == mlx::core::Device::cpu) { + // Signal through a GPU stream so the atomic is updated in GPU - updating + // the atomic in CPU sometimes does not get GPU notified. + static CudaStream stream(device(mlx::core::Device::gpu)); + scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.commit(); + signal(encoder.stream(), value); + encoder.add_completed_handler([ac = ac_]() {}); + } +} + +bool SharedEvent::is_signaled(uint64_t value) const { + nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); + return ac_->load() >= value; +} + +uint64_t SharedEvent::value() const { + nvtx3::scoped_range r("cu::SharedEvent::value"); + return ac_->load(); +} + +} // namespace cu + +/////////////////////////////////////////////////////////////////////////////// +// Event implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct EventImpl { + // CudaEvent is preferred when possible because it is fast, however we have + // to fallback to SharedEvent in following cases: + // 1. the event is used to wait/signal a cpu stream; + // 2. signal value other than 1 has been specified. 
+ std::unique_ptr cuda; + std::unique_ptr shared; + + bool is_created() const { + return cuda || shared; + } + + void ensure_created(Stream s, uint64_t signal_value) { + if (is_created()) { + return; + } + if (s.device == mlx::core::Device::cpu || signal_value > 1) { + nvtx3::mark("Using slow SharedEvent"); + shared = std::make_unique(); + } else { + cuda = std::make_unique(); + } + } +}; + +} // namespace + +Event::Event(Stream s) : stream_(s) { + event_ = std::shared_ptr( + new EventImpl(), [](void* ptr) { delete static_cast(ptr); }); +} + +void Event::wait() { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(); + } else { + event->shared->wait(value()); + } +} + +void Event::wait(Stream s) { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(s); + } else { + event->shared->wait(s, value()); + } +} + +void Event::signal(Stream s) { + auto* event = static_cast(event_.get()); + event->ensure_created(s, value()); + if (event->cuda) { + assert(value() == 1); + event->cuda->record(s); + } else { + event->shared->signal(s, value()); + } +} + +bool Event::is_signaled() const { + auto* event = static_cast(event_.get()); + if (!event->is_created()) { + return false; + } + if (event->cuda) { + assert(value() == 1); + return event->cuda->recorded() && event->cuda->completed(); + } else { + return event->shared->is_signaled(value()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h new file mode 100644 index 000000000..4b56e2e3b --- /dev/null +++ b/mlx/backend/cuda/event.h @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/stream.h" + +#include +#include + +#include + +namespace mlx::core::cu { + +class CudaEventHandle; + +// Wrapper of native cuda event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class CudaEvent { + public: + CudaEvent(); + + void wait(); + void wait(cudaStream_t stream); + void wait(Stream s); + void record(cudaStream_t stream); + void record(Stream s); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + bool recorded() const { + return recorded_; + } + + private: + bool recorded_{false}; + std::shared_ptr event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// CudaEvent so the latter should always be preferred when possible. +class SharedEvent { + public: + using Atomic = cuda::atomic; + + SharedEvent(); + + void wait(uint64_t value); + void wait(cudaStream_t stream, uint64_t value); + void wait(Stream s, uint64_t value); + void signal(uint64_t value); + void signal(cudaStream_t stream, uint64_t value); + void signal(Stream s, uint64_t value); + bool is_signaled(uint64_t value) const; + uint64_t value() const; + + const std::shared_ptr& atomic() const { + return ac_; + } + + private: + std::shared_ptr ac_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp new file mode 100644 index 000000000..f399c4ebb --- /dev/null +++ b/mlx/backend/cuda/fence.cpp @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. 
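// The Fence below is a monotonic counter built on SharedEvent: update() bumps
// the target count and signals it from the stream, wait() blocks until the
// event value has reached that count. A host-only analogue of the same
// protocol using C++20 atomic wait/notify; names are illustrative:

#include <atomic>
#include <cstdint>

struct CounterFenceSketch {
  std::atomic<uint64_t> value{0};
  uint64_t count{0};

  void update() { // producer: publish the next count
    value.store(++count);
    value.notify_all();
  }

  void wait() { // consumer: block until the published count is reached
    uint64_t target = count;
    for (uint64_t cur = value.load(); cur < target; cur = value.load()) {
      value.wait(cur);
    }
  }
};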
+ +#include "mlx/fence.h" +#include "mlx/backend/cuda/event.h" + +namespace mlx::core { + +struct FenceImpl { + uint32_t count; + cu::SharedEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->event.wait(fence->count); +} + +void Fence::update(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp new file mode 100644 index 000000000..4b03a604e --- /dev/null +++ b/mlx/backend/cuda/indexing.cpp @@ -0,0 +1,428 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include "cuda_jit_sources.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core { + +namespace { + +constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; + +void append_indices_arg( + cu::KernelArgs& args, + const std::vector& inputs, + int nidx, + int idx_ndim) { + std::vector indices(nidx); + for (int i = 0; i < nidx; ++i) { + indices[i] = inputs[i + 1].data(); + } + args.append(std::move(indices)); + std::vector indices_shape(nidx * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy_n( + inputs[i + 1].shape().begin(), + idx_ndim, + indices_shape.data() + i * idx_ndim); + } + args.append(std::move(indices_shape)); + std::vector indices_strides(nidx * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy_n( + inputs[i + 1].strides().begin(), + idx_ndim, + indices_strides.data() + i * idx_ndim); + } + args.append(std::move(indices_strides)); +} + +} // namespace + +void Gather::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Gather::eval_gpu"); + assert(inputs.size() > 0); + const auto& src = inputs[0]; + + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + int nidx = inputs.size() - 1; + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (src.size() > INT32_MAX) || (out.size() > INT32_MAX); + + uint32_t slice_size = std::accumulate( + slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); + + std::string module_name = fmt::format( + "gather_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx_dtype), + nidx); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::gather<{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + nidx, + ndim, + large ? 
"int64_t" : "int32_t")); + } + } + return std::make_pair(jit_source_gather, std::move(kernel_names)); + }); + + cu::KernelArgs args; + args.append(src); + args.append(out); + if (large) { + args.append(out.size()); + } else { + args.append(out.size()); + } + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); + args.append_ndim(slice_sizes_); + args.append(slice_size); + args.append(axes_); + append_indices_arg(args, inputs, nidx, idx_ndim); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gather<{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + nidx, + idx_ndim, + large ? "int64_t" : "int32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); +} + +void Scatter::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Gather::eval_gpu"); + assert(inputs.size() > 1); + auto& upd = inputs.back(); + + // Copy src into out. + CopyType copy_type; + if (inputs[0].data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (inputs[0].flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(inputs[0], out, copy_type); + + // Empty update. + if (upd.size() == 0) { + return; + } + + int nidx = axes_.size(); + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (upd.size() > INT32_MAX) || (out.size() > INT32_MAX); + + int32_t upd_post_idx_size = std::accumulate( + upd.shape().begin() + idx_ndim, + upd.shape().end(), + 1, + std::multiplies()); + + const char* op = g_scatter_ops[reduce_type_]; + std::string module_name = fmt::format( + "scatter_{}_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx_dtype), + op, + nidx); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + op, + nidx, + ndim, + large ? "int64_t" : "int32_t")); + } + } + return std::make_pair(jit_source_scatter, std::move(kernel_names)); + }); + + cu::KernelArgs args; + args.append(upd); + args.append(out); + if (large) { + args.append(upd.size()); + } else { + args.append(upd.size()); + } + args.append_ndim(upd.shape()); + args.append_ndim(upd.strides()); + args.append(upd.ndim()); + if (large) { + args.append(upd_post_idx_size); + } else { + args.append(upd_post_idx_size); + } + args.append_ndim(out.shape()); + args.append_ndim(out.strides()); + args.append(out.ndim()); + args.append(axes_); + append_indices_arg(args, inputs, nidx, idx_ndim); + + std::string kernel_name = fmt::format( + "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + op, + nidx, + idx_ndim, + large ? 
"int64_t" : "int32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); +} + +void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("GatherAxis::eval_gpu"); + assert(inputs.size() > 1); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; + + std::string module_name = fmt::format( + "gather_axis_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx.dtype())); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int contiguous = 0; contiguous < 4; ++contiguous) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + ndim, + contiguous & 1 ? true : false, + contiguous & 2 ? true : false, + large ? "int64_t" : "int32_t")); + } + } + } + return std::make_pair(jit_source_gather_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + cu::KernelArgs args; + args.append(src); + args.append(idx); + args.append(out); + if (large) { + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); + } else { + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); + } + args.append(remove_index(idx.shape(), axis_)); + args.append(remove_index(src.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(src.shape(axis_)); + args.append(src.strides(axis_)); + args.append(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + src.ndim() - 1, + src.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "int32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); +} + +void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ScatterAxis::eval_gpu"); + assert(inputs.size() > 2); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + const auto& upd = inputs[2]; + + // Copy src into out. + CopyType copy_type; + if (src.data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (src.flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(src, out, copy_type); + + // Empty update. 
+ if (upd.size() == 0) { + return; + } + + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; + + const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; + std::string module_name = fmt::format( + "scatter_axis_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx.dtype()), + op); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int contiguous = 0; contiguous < 4; ++contiguous) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + ndim, + contiguous & 1 ? true : false, + contiguous & 2 ? true : false, + large ? "int64_t" : "int32_t")); + } + } + } + return std::make_pair(jit_source_scatter_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + cu::KernelArgs args; + args.append(upd); + args.append(idx); + args.append(out); + if (large) { + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); + } else { + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); + } + args.append(remove_index(idx.shape(), axis_)); + args.append(remove_index(upd.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(out.shape(axis_)); + args.append(upd.strides(axis_)); + args.append(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + idx.ndim() - 1, + upd.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "int32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/iterators/general_iterator.cuh b/mlx/backend/cuda/iterators/general_iterator.cuh new file mode 100644 index 000000000..3c8c098c3 --- /dev/null +++ b/mlx/backend/cuda/iterators/general_iterator.cuh @@ -0,0 +1,121 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/cuda/kernel_utils.cuh" + +namespace mlx::core::cu { + +// Iterating non-contiguous array. 
+template +class general_iterator + : public thrust:: + iterator_adaptor, Iterator> { + public: + using super_t = + thrust::iterator_adaptor, Iterator>; + + using reference = typename super_t::reference; + using difference_type = typename super_t::difference_type; + + __host__ __device__ general_iterator( + Iterator it, + IdxT index, + int ndim, + Shape shape, + Strides strides) + : super_t(it), + index_(index), + ndim_(ndim), + shape_(cuda::std::move(shape)), + strides_(cuda::std::move(strides)) {} + + __host__ __device__ IdxT index() const { + return index_; + } + + __host__ __device__ const Shape& shape() const { + return shape_; + } + + __host__ __device__ const Strides& strides() const { + return strides_; + } + + private: + friend class thrust::iterator_core_access; + + __host__ __device__ bool equal(const general_iterator& other) const { + return this->base() == other.base() && this->index() == other.index(); + } + + __host__ __device__ void advance(difference_type n) { + this->index_ += n; + } + + __host__ __device__ void increment() { + this->index_ += 1; + } + + __host__ __device__ void decrement() { + this->index_ -= 1; + } + + __host__ __device__ difference_type + distance_to(const general_iterator& other) const { + _CCCL_ASSERT( + this->base() == other.base(), + "Underlying iterator must point to same base iterator"); + return other.index() - this->index(); + } + + // The dereference is device-only to avoid accidental running in host. + __device__ typename super_t::reference dereference() const { + IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_); + return *(this->base() + offset); + } + + IdxT index_; + int ndim_; + Shape shape_; + Strides strides_; +}; + +template +__host__ __device__ auto make_general_iterator( + Iterator it, + IdxT index, + int ndim, + Shape shape, + Strides strides) { + return general_iterator( + it, index, ndim, cuda::std::move(shape), cuda::std::move(strides)); +} + +template +auto make_general_iterator( + Iterator it, + const std::vector& shape, + const std::vector& strides) { + return make_general_iterator( + it, 0, shape.size(), const_param(shape), const_param(strides)); +} + +template +auto make_general_iterators( + Iterator it, + IdxT size, + const std::vector& shape, + const std::vector& strides) { + auto ndim = shape.size(); + auto shape_arg = const_param(shape); + auto strides_arg = const_param(strides); + return std::make_pair( + make_general_iterator(it, 0, ndim, shape_arg, strides_arg), + make_general_iterator(it, size, ndim, shape_arg, strides_arg)); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/iterators/strided_iterator.cuh b/mlx/backend/cuda/iterators/strided_iterator.cuh new file mode 100644 index 000000000..3ef8d66bd --- /dev/null +++ b/mlx/backend/cuda/iterators/strided_iterator.cuh @@ -0,0 +1,60 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::cu { + +// RandomAccessIterator for strided access to array entries. 
+template +class strided_iterator + : public thrust:: + iterator_adaptor, Iterator> { + public: + using super_t = + thrust::iterator_adaptor, Iterator>; + + using reference = typename super_t::reference; + using difference_type = typename super_t::difference_type; + + __host__ __device__ strided_iterator(Iterator it, Stride stride) + : super_t(it), stride_(stride) {} + + __host__ __device__ Stride stride() const { + return stride_; + } + + private: + friend class thrust::iterator_core_access; + + __host__ __device__ bool equal(const strided_iterator& other) const { + return this->base() == other.base(); + } + + __host__ __device__ void advance(difference_type n) { + this->base_reference() += n * stride_; + } + + __host__ __device__ void increment() { + this->base_reference() += stride_; + } + + __host__ __device__ void decrement() { + this->base_reference() -= stride_; + } + + __host__ __device__ difference_type + distance_to(const strided_iterator& other) const { + const difference_type dist = other.base() - this->base(); + _CCCL_ASSERT( + dist % stride() == 0, + "Underlying iterator difference must be divisible by the stride"); + return dist / stride(); + } + + Stride stride_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp new file mode 100644 index 000000000..343db902e --- /dev/null +++ b/mlx/backend/cuda/jit_module.cpp @@ -0,0 +1,323 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/version.h" + +#include "cuda_jit_sources.h" + +#include +#include +#include +#include + +#include +#include +#include + +namespace mlx::core::cu { + +namespace { + +#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd)) + +void check_nvrtc_error(const char* name, nvrtcResult err) { + if (err != NVRTC_SUCCESS) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, nvrtcGetErrorString(err))); + } +} + +// Return the location of the CUDA toolkit. +const std::string& cuda_home() { + static std::string home = []() -> std::string { + const char* home = std::getenv("CUDA_HOME"); + if (home) { + return home; + } + home = std::getenv("CUDA_PATH"); + if (home) { + return home; + } +#if defined(__linux__) + home = "/usr/local/cuda"; + if (std::filesystem::exists(home)) { + return home; + } +#endif + throw std::runtime_error( + "Environment variable CUDA_HOME or CUDA_PATH is not set."); + }(); + return home; +} + +// Return the location of CCCL headers shipped with the distribution. +bool get_cccl_include(std::string* out) { + auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl"; + if (!std::filesystem::exists(cccl_headers)) { + return false; + } + *out = fmt::format("--include-path={}", cccl_headers.string()); + return true; +} + +// Get the cache directory for storing compiled results. +const std::filesystem::path& ptx_cache_dir() { + static std::filesystem::path cache = []() -> std::filesystem::path { + std::filesystem::path cache; + if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) { + cache = c; + } else { + cache = + std::filesystem::temp_directory_path() / "mlx" / version() / "ptx"; + } + if (!std::filesystem::exists(cache)) { + std::error_code error; + if (!std::filesystem::create_directories(cache, error)) { + return std::filesystem::path(); + } + } + return cache; + }(); + return cache; +} + +// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. 
+bool read_cached_ptx( + const std::filesystem::path& cache_dir, + const std::string& module_name, + std::vector* ptx, + std::vector>* ptx_kernels) { + if (cache_dir.empty()) { + return false; + } + + auto ptx_path = cache_dir / (module_name + ".ptx"); + std::error_code error; + auto ptx_size = std::filesystem::file_size(ptx_path, error); + if (error) { + return false; + } + std::ifstream ptx_file(ptx_path, std::ios::binary); + if (!ptx_file.good()) { + return false; + } + ptx->resize(ptx_size); + ptx_file.read(ptx->data(), ptx_size); + + std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + std::string line; + while (std::getline(txt_file, line)) { + auto tab = line.find('\t'); + if (tab != std::string::npos) { + ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1)); + } + } + return true; +} + +// Write the |ptx| and |ptx_kernels| to |cache_dir| with |name|. +void write_cached_ptx( + const std::filesystem::path& cache_dir, + const std::string& module_name, + const std::vector& ptx, + const std::vector>& ptx_kernels) { + if (cache_dir.empty()) { + return; + } + + std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); + if (!ptx.empty()) { + ptx_file.write(&ptx.front(), ptx.size()); + } + std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + for (const auto& [name, mangled] : ptx_kernels) { + txt_file << name << "\t" << mangled << std::endl; + } +} + +// Return if |device|'s version is not newer than |major|.|minor| version. +inline bool version_lower_equal(Device& device, int major, int minor) { + if (device.compute_capability_major() < major) { + return true; + } else if (device.compute_capability_major() == major) { + return device.compute_capability_minor() <= minor; + } else { + return false; + } +} + +// Return whether NVRTC supports compiling to |device|'s SASS code. +bool compiler_supports_device_sass(Device& device) { + int nvrtc_major, nvrtc_minor; + CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor)); + if (nvrtc_major < 9) { + return false; + } else if (nvrtc_major == 9) { + return version_lower_equal(device, 7, 2); + } else if (nvrtc_major == 10) { + return version_lower_equal(device, 7, 5); + } else if (nvrtc_major == 11 && nvrtc_minor == 0) { + return version_lower_equal(device, 8, 0); + } else if (nvrtc_major == 11 && nvrtc_minor < 8) { + return version_lower_equal(device, 8, 6); + } else { + return true; + } +} + +#define INCLUDE_PREFIX "mlx/backend/cuda/device/" + +constexpr const char* g_include_names[] = { + INCLUDE_PREFIX "atomic_ops.cuh", + INCLUDE_PREFIX "binary_ops.cuh", + INCLUDE_PREFIX "cast_op.cuh", + INCLUDE_PREFIX "config.h", + INCLUDE_PREFIX "complex.cuh", + INCLUDE_PREFIX "fp16_math.cuh", + INCLUDE_PREFIX "indexing.cuh", + INCLUDE_PREFIX "scatter_ops.cuh", + INCLUDE_PREFIX "unary_ops.cuh", + INCLUDE_PREFIX "ternary_ops.cuh", + INCLUDE_PREFIX "utils.cuh", +}; + +#undef INCLUDE_PREFIX + +constexpr const char* g_headers[] = { + jit_source_atomic_ops, + jit_source_binary_ops, + jit_source_cast_op, + jit_source_config, + jit_source_complex, + jit_source_fp16_math, + jit_source_indexing, + jit_source_scatter_ops, + jit_source_unary_ops, + jit_source_ternary_ops, + jit_source_utils, +}; + +} // namespace + +JitModule::JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder) { + // Check cache. 
+ std::vector ptx; + std::vector> ptx_kernels; + if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) { + // Create program. + auto [source_code, kernel_names] = builder(); + nvrtcProgram prog; + CHECK_NVRTC_ERROR(nvrtcCreateProgram( + &prog, + source_code.c_str(), + (module_name + ".cu").c_str(), + std::size(g_headers), + g_headers, + g_include_names)); + std::unique_ptr prog_freer( + &prog, + [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); + for (const auto& name : kernel_names) { + CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); + } + + // Compile program. + std::vector args; + bool use_sass = compiler_supports_device_sass(device); + std::string compute = fmt::format( + "--gpu-architecture={}_{}{}", + use_sass ? "sm" : "compute", + device.compute_capability_major(), + device.compute_capability_minor()); + args.push_back(compute.c_str()); + std::string cccl_include; + if (get_cccl_include(&cccl_include)) { + args.push_back(cccl_include.c_str()); + } + std::string cuda_include = + fmt::format("--include-path={}/include", cuda_home()); + args.push_back(cuda_include.c_str()); + nvrtcResult compile_result = + nvrtcCompileProgram(prog, args.size(), args.data()); + if (compile_result != NVRTC_SUCCESS) { + size_t log_size; + CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); + std::vector log(log_size + 1, 0); + CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data())); + throw std::runtime_error( + fmt::format("Failed to compile kernel: {}.", log.data())); + } + + // Get mangled names of kernel names. + for (const auto& name : kernel_names) { + const char* mangled; + CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled)); + ptx_kernels.emplace_back(name, mangled); + } + + // Get ptx data. + size_t ptx_size; + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size)); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size)); + } + ptx.resize(ptx_size, 0); + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data())); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); + } + write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels); + } + + // Load module. + char jit_log[4089] = {}; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; + void* values[] = {jit_log, reinterpret_cast(std::size(jit_log) - 1)}; + CUresult jit_result = cuModuleLoadDataEx( + &module_, ptx.data(), std::size(options), options, values); + if (jit_result != CUDA_SUCCESS) { + throw std::runtime_error(fmt::format( + "Failed to load compiled {} kernel: {}.", module_name, jit_log)); + } + + // Load kernels. 
+ for (const auto& [name, mangled] : ptx_kernels) { + CUfunction kernel; + CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); + kernels_[name] = kernel; + } +} + +JitModule::~JitModule() { + CHECK_CUDA_ERROR(cuModuleUnload(module_)); +} + +CUfunction JitModule::get_kernel(const std::string& kernel_name) { + auto it = kernels_.find(kernel_name); + if (it == kernels_.end()) { + throw std::runtime_error( + fmt::format("There is no kernel named {}.", kernel_name)); + } + return it->second; +} + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder) { + static std::unordered_map map; + auto it = map.find(name); + if (it == map.end()) { + it = map.try_emplace(name, cu::device(device), name, builder).first; + } + return it->second; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h new file mode 100644 index 000000000..57da7c87e --- /dev/null +++ b/mlx/backend/cuda/jit_module.h @@ -0,0 +1,107 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/config.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core::cu { + +class Device; + +using KernelBuilderResult = std::pair< + /* source code */ std::string, + /* kernel names */ std::vector>; +using KernelBuilder = std::function; + +struct KernelArgs { + void** args() { + return args_.data(); + } + + void append(const array& a) { + append(reinterpret_cast(a.data())); + } + + template + void append(T val) { + storage_.emplace_back(val); + append_ptr(&storage_.back()); + } + + template + void append(std::vector vec) { + if (vec.empty()) { + // The nullptr can not be used as arg, pass something not null. + append(std::monostate{}); + } else { + append_ptr(vec.data()); + storage_.emplace_back(std::move(vec)); + } + } + + // Make sure the arg is copied to an array with size of NDIM. + template + void append_ndim(std::vector vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + vec.resize(NDIM); + append(std::move(vec)); + } + + void append_ptr(const void* v) { + args_.push_back(const_cast(v)); + } + + private: + std::vector args_; + + // The cuLaunchKernel API requires passing pointers to arguments so store + // temporary values untill kernel is launched. + using Arg = std::variant< + std::monostate, + CUdeviceptr, + int32_t, + uint32_t, + int64_t, + std::vector, + std::vector, + std::vector>; + std::deque storage_; +}; + +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + CUfunction get_kernel(const std::string& kernel_name); + + private: + CUmodule module_{nullptr}; + std::unordered_map kernels_; +}; + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernel_utils.cu b/mlx/backend/cuda/kernel_utils.cu new file mode 100644 index 000000000..7b87aa5b0 --- /dev/null +++ b/mlx/backend/cuda/kernel_utils.cu @@ -0,0 +1,33 @@ +// Copyright © 2025 Apple Inc. 
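+
+// Thin wrappers that convert the tuple-based launch helpers from
+// backend/common/utils.h into CUDA dim3 values. A minimal usage sketch
+// (illustrative only; `kernel` and the extents are placeholders):
+//
+//   auto [grid, block] = get_grid_and_block(total_x, total_y, 1);
+//   encoder.add_kernel_node(kernel, grid, block, args...);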
+ +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/kernel_utils.cuh" + +namespace mlx::core { + +dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) { + Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) { + Dims dims = get_2d_grid_dims_common(shape, strides); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +dim3 get_2d_grid_dims( + const Shape& shape, + const Strides& strides, + size_t divisor) { + Dims dims = get_2d_grid_dims_common(shape, strides, divisor); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +std::pair get_grid_and_block(int dim0, int dim1, int dim2) { + auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2); + auto [gx, gy, gz] = grid; + auto [bx, by, bz] = block; + return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz)); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh new file mode 100644 index 000000000..24c81f2fb --- /dev/null +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -0,0 +1,172 @@ +// Copyright © 2025 Apple Inc. + +// This file includes host-only utilies for writing CUDA kernels, the difference +// from backend/cuda/device/utils.cuh is that the latter file only include +// device-only code. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/backend/cuda/device/utils.cuh" + +#include +#include +#include +#include +#include + +namespace mlx::core { + +template +void dispatch_1_2_3(int n, F&& f) { + switch (n) { + case 1: + f(std::integral_constant{}); + break; + case 2: + f(std::integral_constant{}); + break; + case 3: + f(std::integral_constant{}); + break; + } +} + +template +void dispatch_bool(bool v, F&& f) { + if (v) { + f(std::true_type{}); + } else { + f(std::false_type{}); + } +} + +template +void dispatch_block_dim(int threads, F&& f) { + if (threads <= WARP_SIZE) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 2) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 4) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 8) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 16) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +// Maps CPU types to CUDA types. +template +struct CTypeToCudaType { + using type = T; +}; + +template <> +struct CTypeToCudaType { + using type = __half; +}; + +template <> +struct CTypeToCudaType { + using type = __nv_bfloat16; +}; + +template <> +struct CTypeToCudaType { + using type = cu::complex64_t; +}; + +template +using cuda_type_t = typename CTypeToCudaType::type; + +// Type traits for detecting floating numbers. +template +inline constexpr bool is_floating_v = + cuda::std::is_same_v || cuda::std::is_same_v || + cuda::std::is_same_v || cuda::std::is_same_v; + +// Type traits for detecting complex numbers. +template +inline constexpr bool is_complex_v = cuda::std::is_same_v || + cuda::std::is_same_v; + +// Type traits for detecting complex or real floating point numbers. +template +inline constexpr bool is_inexact_v = is_floating_v || is_complex_v; + +// Utility to copy data from vector to array in host. 
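+// The returned cuda::std::array is a plain value type, so it can be passed
+// by value as a __grid_constant__ kernel parameter (see the Shape/Strides
+// arguments of the rbits kernel in random.cu for this pattern).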
+template +inline cuda::std::array const_param(const std::vector& vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + cuda::std::array result; + std::copy_n(vec.begin(), vec.size(), result.begin()); + return result; +} + +// Compute the grid and block dimensions, check backend/common/utils.h for docs. +dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides); +dim3 get_2d_grid_dims( + const Shape& shape, + const Strides& strides, + size_t divisor); +std::pair get_grid_and_block(int dim0, int dim1, int dim2); + +// Return a block size that achieves maximum potential occupancy for kernel. +template +inline uint max_occupancy_block_dim(T kernel) { + int _, block_dim; + if constexpr (std::is_same_v) { + CHECK_CUDA_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + } else { + CHECK_CUDA_ERROR( + cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + } + return block_dim; +} + +// Get the num_blocks and block_dims that maximize occupancy for |kernel|, +// assuming each thread handles |work_per_thread| elements of |arr|. +template +inline std::tuple get_launch_args( + T kernel, + size_t size, + const Shape& shape, + const Strides& strides, + bool large, + int work_per_thread = 1) { + size_t nthreads = cuda::ceil_div(size, work_per_thread); + uint block_dim = max_occupancy_block_dim(kernel); + if (block_dim > nthreads) { + block_dim = nthreads; + } + dim3 num_blocks; + if (large) { + num_blocks = get_2d_grid_dims(shape, strides, work_per_thread); + num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim); + } else { + num_blocks.x = cuda::ceil_div(nthreads, block_dim); + } + return std::make_tuple(num_blocks, block_dim); +} + +template +inline std::tuple get_launch_args( + T kernel, + const array& arr, + bool large, + int work_per_thread = 1) { + return get_launch_args( + kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu new file mode 100644 index 000000000..5fbf949d7 --- /dev/null +++ b/mlx/backend/cuda/layer_norm.cu @@ -0,0 +1,405 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/strided_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +inline __device__ float3 plus_f3(const float3& a, const float3& b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +// Similar to cub::BlockReduce, but result is broadcasted to every thread. +template +struct BlockBroadcastReduce { + static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); + static_assert(BLOCK_DIM % WARP_SIZE == 0); + using TempStorage = T[BLOCK_DIM / WARP_SIZE]; + + cg::thread_block& block; + TempStorage& temp; + + template + __device__ T Reduce(const T& input, const Op& op, const T& init_value) { + auto warp = cg::tiled_partition(block); + T x = cg::reduce(warp, input, op); + if (warp.thread_rank() == 0) { + temp[warp.meta_group_rank()] = x; + } + block.sync(); + x = warp.thread_rank() < warp.meta_group_size() ? 
temp[warp.thread_rank()] + : init_value; + return cg::reduce(warp, x, op); + } + + __device__ T Sum(const T& input) { + return Reduce(input, cg::plus{}, T{}); + } +}; + +template +__global__ void layer_norm( + const T* x, + const T* w, + const T* b, + T* out, + float eps, + int32_t axis_size, + int64_t w_stride, + int64_t b_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceT = BlockBroadcastReduce; + __shared__ typename BlockReduceT::TempStorage temp; + + x += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + // Sum. + float sum = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + cub::LoadDirectBlocked(index, x, xn, axis_size); + sum += static_cast(cub::ThreadReduce(xn, cuda::std::plus<>{})); + } + sum = BlockReduceT{block, temp}.Sum(sum); + + // Mean. + float mean = sum / axis_size; + + // Normalizer. + float normalizer = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size, mean); + for (int i = 0; i < N_READS; ++i) { + float t = static_cast(xn[i]) - mean; + normalizer += t * t; + } + } + normalizer = BlockReduceT{block, temp}.Sum(normalizer); + normalizer = rsqrt(normalizer / axis_size + eps); + + // Outputs. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T bn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float norm = (static_cast(xn[i]) - mean) * normalizer; + xn[i] = wn[i] * static_cast(norm) + bn[i]; + } + cub::StoreDirectBlocked(index, out, xn, axis_size); + } +} + +template +__global__ void layer_norm_vjp( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceF = BlockBroadcastReduce; + using BlockReduceF3 = BlockBroadcastReduce; + __shared__ union { + typename BlockReduceF::TempStorage f; + typename BlockReduceF3::TempStorage f3; + } temp; + + x += grid.block_rank() * axis_size; + g += grid.block_rank() * axis_size; + gx += grid.block_rank() * axis_size; + gw += grid.block_rank() * axis_size; + + // Sum. + float sum = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + cub::LoadDirectBlocked(index, x, xn, axis_size); + sum += static_cast(cub::ThreadReduce(xn, cuda::std::plus<>{})); + } + sum = BlockReduceF{block, temp.f}.Sum(sum); + + // Mean. + float mean = sum / axis_size; + + // Normalizer. 
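+  // One pass accumulates sum(w*g), sum(w*g*(x - mean)) and sum((x - mean)^2)
+  // into a float3; the block-wide reduction below turns these into meanwg,
+  // meanwgxc and the variance used for the normalizer.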
+ float3 factors = {}; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T xn[N_READS]; + T wn[N_READS] = {}; + T gn[N_READS] = {}; + auto index = r * BLOCK_DIM + block.thread_rank(); + cub::LoadDirectBlocked(index, x, xn, axis_size, mean); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float t = static_cast(xn[i]) - mean; + float wi = wn[i]; + float gi = gn[i]; + float wg = wi * gi; + factors = plus_f3(factors, {wg, wg * t, t * t}); + } + } + factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {}); + float meanwg = factors.x / axis_size; + float meanwgxc = factors.y / axis_size; + float normalizer2 = 1 / (factors.z / axis_size + eps); + float normalizer = sqrt(normalizer2); + + // Outputs. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T gn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float xi = (static_cast(xn[i]) - mean) * normalizer; + float wi = wn[i]; + float gi = gn[i]; + xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2; + if constexpr (HAS_W) { + wn[i] = gi * xi; + } + } + cub::StoreDirectBlocked(index, gx, xn, axis_size); + if constexpr (HAS_W) { + cub::StoreDirectBlocked(index, gw, wn, axis_size); + } + } +} + +} // namespace cu + +namespace fast { + +bool LayerNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +// TODO: There are duplicate code with backend/metal/normalization.cpp +void LayerNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("LayerNorm::eval_gpu"); + auto& s = stream(); + auto& out = outputs[0]; + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + const array x = set_output(inputs[0]); + const array& w = inputs[1]; + const array& b = inputs[2]; + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + int64_t b_stride = (b.ndim() == 1) ? 
b.strides()[0] : 0; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(b); + encoder.set_output_array(out); + dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); + }); + }); +} + +void LayerNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("LayerNormVJP::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + // Ensure row contiguity. We could relax this step by checking that the array + // is contiguous (no broadcasts or holes) and that the input strides are the + // same as the cotangent strides but for now this is simpler. + auto check_input = [&s](const array& x, bool& copied) { + if (x.flags().row_contiguous) { + copied = false; + return x; + } + copied = true; + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + return x_copy; + }; + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[3].is_donatable(); + bool copied; + auto x = check_input(inputs[0], copied); + donate_x |= copied; + const array& w = inputs[1]; + const array& b = inputs[2]; + bool g_copied; + auto g = check_input(inputs[3], g_copied); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + array& gb = outputs[2]; + + // Check whether we had a weight. + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs. + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(allocator::malloc(gx.nbytes())); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w and allocate the output + // gradient accumulators. + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + bool g_in_gw = false; + if (has_w) { + if (!g_in_gx && donate_g) { + g_in_gw = true; + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + encoder.add_temporary(gw_temp); + } + } + + // The gradient for b in case we had a b. 
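+  // gb is just the cotangent summed over rows, so reuse the strided column
+  // reduction (the same kernel used below to accumulate gw).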
+ bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size); + if (has_gb) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); + } + + // Insert dependency if `g` was donated + if ((g_in_gx || g_in_gw) && has_gb) { + encoder.set_input_array(gb); + } + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); + + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu new file mode 100644 index 000000000..afc52826f --- /dev/null +++ b/mlx/backend/cuda/logsumexp.cu @@ -0,0 +1,162 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +inline __device__ T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + return __expf(x); +} + +template +__global__ void logsumexp(const T* in, T* out, int axis_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + in += grid.block_rank() * axis_size; + + cg::greater max_op; + cg::plus plus_op; + + // Thread reduce. + AccT prevmax; + AccT maxval = Limits::finite_min(); + AccT normalizer = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + AccT vals[N_READS]; + cub::LoadDirectBlocked( + r * BLOCK_DIM + block.thread_rank(), + make_cast_iterator(in), + vals, + axis_size, + Limits::min()); + prevmax = maxval; + maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); + // Online normalizer calculation for softmax: + // https://github.com/NVIDIA/online-softmax + normalizer = normalizer * softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer = normalizer + softmax_exp(vals[i] - maxval); + } + } + + // First warp reduce. + prevmax = maxval; + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = cg::reduce(warp, normalizer, plus_op); + + __shared__ AccT local_max[WARP_SIZE]; + __shared__ AccT local_normalizer[WARP_SIZE]; + + // Write to shared memory and do second warp reduce. 
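+  // Each warp leader publishes its partial max/normalizer; the first
+  // meta_group_size() lanes then reload those partials (other lanes use the
+  // identity) and a second warp reduce gives every thread the row-wide
+  // max and sum.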
+ prevmax = maxval; + if (warp.thread_rank() == 0) { + local_max[warp.meta_group_rank()] = maxval; + } + block.sync(); + maxval = warp.thread_rank() < warp.meta_group_size() + ? local_max[warp.thread_rank()] + : Limits::finite_min(); + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + if (warp.thread_rank() == 0) { + local_normalizer[warp.meta_group_rank()] = normalizer; + } + block.sync(); + normalizer = warp.thread_rank() < warp.meta_group_size() + ? local_normalizer[warp.thread_rank()] + : AccT{}; + normalizer = cg::reduce(warp, normalizer, plus_op); + + // Write output. + if (block.thread_rank() == 0) { + out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval; + } +} + +} // namespace cu + +void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("LogSumExp::eval_gpu"); + assert(inputs.size() == 1); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + // Make sure that the last dimension is contiguous. + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto in = ensure_contiguous(inputs[0]); + if (in.flags().row_contiguous) { + out.set_data(allocator::malloc(out.nbytes())); + } else { + auto n = in.shape(-1); + auto flags = in.flags(); + auto strides = in.strides(); + for (auto& s : strides) { + s /= n; + } + bool col_contig = strides[0] == 1; + for (int i = 1; col_contig && i < strides.size(); ++i) { + col_contig &= + (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); + } + flags.col_contiguous = col_contig; + out.set_data( + allocator::malloc(in.nbytes() / n), + in.data_size() / n, + std::move(strides), + flags); + } + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::logsumexp; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp new file mode 100644 index 000000000..e11c68b7d --- /dev/null +++ b/mlx/backend/cuda/matmul.cpp @@ -0,0 +1,489 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/matmul.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) + +void check_cublas_error(const char* name, cublasStatus_t err) { + if (err != CUBLAS_STATUS_SUCCESS) { + // TODO: Use cublasGetStatusString when it is widely available. 
+ throw std::runtime_error( + fmt::format("{} failed with code: {}.", name, static_cast(err))); + } +} + +class MatMul { + public: + MatMul( + Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride) + : handle_(device.lt_handle()) { + heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; + + auto scale_type = dtype_to_cuda_type(dtype); + if (dtype == bfloat16 || dtype == float16) { + scale_type = CUDA_R_32F; + } + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( + &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); + int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(int32_t))); + cublasOperation_t op = CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &op, + sizeof(cublasOperation_t))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSB, + &op, + sizeof(cublasOperation_t))); + + auto type = dtype_to_cuda_type(dtype); + a_desc_ = create_matrix_layout( + type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); + b_desc_ = create_matrix_layout( + type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); + out_desc_ = create_matrix_layout( + type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); + + // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB + // for Hopper+: + // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace + uint64_t MiB = 1024 * 1024; + uint64_t workspace_size = + device.compute_capability_major() >= 9 ? 
32 * MiB : 4 * MiB; + + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( + pref_, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(uint64_t))); + } + + MatMul( + Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + bool c_transposed, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride) + : MatMul( + device, + dtype, + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride) { + auto type = dtype_to_cuda_type(dtype); + c_desc_ = create_matrix_layout( + type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride); + } + + ~MatMul() { + cublasLtMatrixLayoutDestroy(a_desc_); + cublasLtMatrixLayoutDestroy(b_desc_); + cublasLtMatrixLayoutDestroy(c_desc_); + cublasLtMatrixLayoutDestroy(out_desc_); + cublasLtMatmulDescDestroy(matmul_desc_); + } + + void run( + cu::CommandEncoder& encoder, + void* out, + void* a, + void* b, + void* c = nullptr, + float alpha = 1, + float beta = 0) { + if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { + int ret = 0; + CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( + handle_, + matmul_desc_, + a_desc_, + b_desc_, + out_desc_, + out_desc_, + pref_, + 1, + &heuristic_, + &ret)); + if (ret == 0) { + throw std::runtime_error("Can not find algorithm for matmul."); + } + } + + void* workspace_ptr = nullptr; + if (heuristic_.workspaceSize > 0) { + array workspace( + allocator::malloc(heuristic_.workspaceSize), + {static_cast(heuristic_.workspaceSize)}, + int8); + encoder.add_temporary(workspace); + workspace_ptr = workspace.data(); + } + + auto capture = encoder.capture_context(); + CHECK_CUBLAS_ERROR(cublasLtMatmul( + handle_, + matmul_desc_, + &alpha, + a, + a_desc_, + b, + b_desc_, + &beta, + c ? c : out, + c ? c_desc_ : out_desc_, + out, + out_desc_, + &heuristic_.algo, + workspace_ptr, + heuristic_.workspaceSize, + encoder.stream())); + } + + private: + cublasComputeType_t dtype_to_compute_type(Dtype dtype) { + switch (dtype) { + case float16: + return CUBLAS_COMPUTE_32F; + case bfloat16: + return CUBLAS_COMPUTE_32F; + case float32: + return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 + : CUBLAS_COMPUTE_32F; + case float64: + case complex64: + return CUBLAS_COMPUTE_64F; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); + } + } + + cudaDataType_t dtype_to_cuda_type(Dtype dtype) { + switch (dtype) { + case float16: + return CUDA_R_16F; + case bfloat16: + return CUDA_R_16BF; + case float32: + return CUDA_R_32F; + case float64: + return CUDA_R_64F; + case complex64: + return CUDA_C_32F; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); + } + } + + cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride) { + cublasLtMatrixLayout_t desc; + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); + cublasLtOrder_t order = + transposed ? 
CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW; + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t))); + if (batch_count > 1) { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch_count, + sizeof(int32_t))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, + sizeof(int64_t))); + } + return desc; + } + + cublasLtHandle_t handle_{nullptr}; + cublasLtMatmulDesc_t matmul_desc_{nullptr}; + cublasLtMatmulPreference_t pref_{nullptr}; + cublasLtMatrixLayout_t a_desc_{nullptr}; + cublasLtMatrixLayout_t b_desc_{nullptr}; + cublasLtMatrixLayout_t c_desc_{nullptr}; + cublasLtMatrixLayout_t out_desc_{nullptr}; + cublasLtMatmulHeuristicResult_t heuristic_; +}; + +} // namespace cu + +namespace { + +std::tuple +check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (sty == 1 && stx == arr.shape(-1)) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + enc.add_temporary(arr_copy); + return std::make_tuple(false, arr.shape(-1), arr_copy); + } +} + +} // namespace + +void Matmul::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Matmul::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + assert(inputs.size() == 2); + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + // Return 0s if either input is empty. 
+ if (a_pre.size() == 0 || b_pre.size() == 0) { + array zero(0, a_pre.dtype()); + encoder.add_temporary(zero); + fill_gpu(zero, out, s); + return; + } + + out.set_data(allocator::malloc(out.nbytes())); + + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); + + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + + auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); + + auto batch_count = out.size() / (M * N); + + // Collapse batches into M if needed + if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && + b_batch_strides.back() == 0) { + M *= batch_shape.back(); + batch_count = 1; + + a_batch_strides = {0}; + b_batch_strides = {0}; + batch_shape = {1}; + } + + ///////////////////////////////////////////////////////////////////////////// + // Invoke cublasLt + + cu::MatMul matmul( + cu::device(s.device), + a.dtype(), + a_transposed, + M, + K, + lda, + b_transposed, + K, + N, + ldb, + batch_shape.back(), + a_batch_strides.back(), + b_batch_strides.back()); + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + auto nbatch = batch_count / batch_shape.back(); + if (nbatch == 1) { + matmul.run(encoder, out.data(), a.data(), b.data()); + return; + } + + ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); + ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); + for (size_t i = 0; i < nbatch; ++i) { + matmul.run( + encoder, + out.data() + out.itemsize() * i * batch_shape.back() * M * N, + a.data() + a.itemsize() * a_it.loc, + b.data() + b.itemsize() * b_it.loc); + a_it.step(); + b_it.step(); + } +} + +void AddMM::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("AddMM::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + assert(inputs.size() == 3); + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + auto& c_pre = inputs[2]; + + out.set_data(allocator::malloc(out.nbytes())); + + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); + auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre); + + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + + auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] = + collapse_batches(a, b, c); + + auto batch_count = out.size() / (M * N); + + // Collapse batches into M if needed + if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && + c_batch_strides.back() == M * c.strides()[c.ndim() - 
2] && + b_batch_strides.back() == 0) { + M *= batch_shape.back(); + batch_count = 1; + + a_batch_strides = {0}; + b_batch_strides = {0}; + c_batch_strides = {0}; + batch_shape = {1}; + } + + ///////////////////////////////////////////////////////////////////////////// + // Invoke cublasLt + + cu::MatMul matmul( + cu::device(s.device), + a.dtype(), + a_transposed, + M, + K, + lda, + b_transposed, + K, + N, + ldb, + c_transposed, + ldc, + batch_shape.back(), + a_batch_strides.back(), + b_batch_strides.back(), + c_batch_strides.back()); + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(c); + encoder.set_output_array(out); + + auto nbatch = batch_count / batch_shape.back(); + if (nbatch == 1) { + matmul.run( + encoder, + out.data(), + a.data(), + b.data(), + c.data(), + alpha_, + beta_); + return; + } + + ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); + ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); + for (size_t i = 0; i < nbatch; ++i) { + matmul.run( + encoder, + out.data() + out.itemsize() * i * batch_shape.back() * M * N, + a.data() + a.itemsize() * a_it.loc, + b.data() + b.itemsize() * b_it.loc, + c.data() + c.itemsize() * c_it.loc, + alpha_, + beta_); + a_it.step(); + b_it.step(); + c_it.step(); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/no_cuda.cpp b/mlx/backend/cuda/no_cuda.cpp new file mode 100644 index 000000000..8a394c9e3 --- /dev/null +++ b/mlx/backend/cuda/no_cuda.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cuda.h" + +namespace mlx::core::cu { + +bool is_available() { + return false; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu new file mode 100644 index 000000000..a7f4e8f66 --- /dev/null +++ b/mlx/backend/cuda/primitives.cu @@ -0,0 +1,104 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/arange.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/distributed/primitives.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Arange::eval_gpu"); + assert(inputs.size() == 0); + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + auto& encoder = cu::get_command_encoder(stream()); + encoder.set_output_array(out); + auto capture = encoder.capture_context(); + dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(encoder.stream()), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); + }); +} + +bool fast::ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + return true; +} + +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +#define NO_GPU_USE_FALLBACK(func) \ + bool func::use_fallback(Stream s) { \ + return true; \ + } \ + NO_GPU_MULTI(func) + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +NO_GPU(BlockMaskedMM) +NO_GPU(Convolution) +NO_GPU(DynamicSlice) +NO_GPU(DynamicSliceUpdate) +NO_GPU(FFT) +NO_GPU(GatherMM) +NO_GPU(GatherQMM) +NO_GPU(Hadamard) +NO_GPU(Load) +NO_GPU_MULTI(LUF) +NO_GPU_MULTI(QRF) +NO_GPU(QuantizedMatmul) +NO_GPU(SegmentedMM) +NO_GPU_MULTI(SVD) +NO_GPU(Inverse) +NO_GPU(Cholesky) +NO_GPU_MULTI(Eig) +NO_GPU_MULTI(Eigh) + +namespace fast { +NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(CustomKernel) +} // namespace fast + +namespace distributed { +NO_GPU_MULTI(AllReduce) +NO_GPU_MULTI(AllGather) +NO_GPU_MULTI(Send) +NO_GPU_MULTI(Recv) +} // namespace distributed + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized.cu b/mlx/backend/cuda/quantized.cu new file mode 100644 index 000000000..12a1f6fe4 --- /dev/null +++ b/mlx/backend/cuda/quantized.cu @@ -0,0 +1,383 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include + +namespace mlx::core { +namespace cu { + +namespace cg = cooperative_groups; + +template +inline constexpr __device__ short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr __device__ short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 
5 : 3); +} + +template +__global__ void +affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim = cg::this_grid().dim_threads(); + constexpr float eps = 1e-7; + constexpr int simd_size = WARP_SIZE; + constexpr float n_bins = (1 << bits) - 1; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_reduce = group_size / simd_size; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; + constexpr int writes_per_pack = + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + + size_t offset = tidx + grid_dim.x * size_t(tidy); + size_t in_index = offset * values_per_reduce; + if (in_index >= size) { + return; + } + size_t out_index = power_of_2_bits + ? offset * writes_per_pack + : offset * bytes_per_pack / writes_per_reduce; + + float w_thread[values_per_reduce]; + float w_min = Limits::max(); + float w_max = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + float val = w[in_index + i]; + w_thread[i] = val; + w_min = min(w_min, val); + w_max = max(w_max, val); + } + + cg::greater max_op; + cg::less min_op; + auto warp = cg::tiled_partition(cg::this_thread_block()); + + w_min = cg::reduce(warp, w_min, min_op); + w_max = cg::reduce(warp, w_max, max_op); + + float scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + float edge = side ? w_min : w_max; + float q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + float bias = at_zero ? 
0 : edge; + + // Write out the scales and biases + size_t gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = static_cast(scale); + biases[gindex] = static_cast(bias); + } + + using OutType = std::conditional_t; + OutType output = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output |= val << (bits * (i % pack_factor)); + } + + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; + output = 0; + } else { +#pragma clang loop unroll(full) + for (int j = 1; j < writes_per_reduce; j++) { + uint8_t sval = warp.shfl_down(val, j); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); + } + } + } + if constexpr (bits == 3 || bits == 6) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else if constexpr (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } + } else { + if constexpr (writes_per_reduce > 0) { + if (out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; + } + } + } +} + +template +__global__ void affine_dequantize( + const uint8_t* w, + const T* scales, + const T* biases, + T* out, + size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim = cg::this_grid().dim_threads(); + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + size_t offset = tidx + grid_dim.x * size_t(tidy); + size_t oindex = offset * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + out += oindex; + + if constexpr (bits == 3) { + w += offset * bytes_per_pack; + out[0] = static_cast(w[0] & 0x7) * scale + bias; + out[1] = static_cast((w[0] & 0x38) >> 3) * scale + bias; + out[2] = (static_cast((w[0] & 0xc0) >> 6) + + static_cast((w[1] & 0x1) << 2)) * + scale + + bias; + out[3] = static_cast((w[1] & 0xe) >> 1) * scale + bias; + out[4] = static_cast((w[1] & 0x70) >> 4) * scale + bias; + out[5] = (static_cast((w[1] & 0x80) >> 7) + + static_cast((w[2] & 0x3) << 1)) * + scale + + bias; + out[6] = static_cast((w[2] & 0x1c) >> 2) * scale + bias; + out[7] = static_cast((w[2] & 0xe0) >> 5) * scale + bias; + } else if constexpr (bits == 5) { + w += offset * bytes_per_pack; + out[0] = static_cast(w[0] & 0x1f) * scale + bias; + out[1] = (static_cast((w[0] & 0xe0) >> 5) + + static_cast((w[1] & 0x3) << 3)) * + scale + + bias; + out[2] = static_cast((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (static_cast((w[1] & 0x80) >> 7) + + static_cast((w[2] & 0xf) << 1)) * + scale + + bias; + out[4] = (static_cast((w[2] & 0xf0) >> 4) + + static_cast((w[3] & 0x1) << 4)) * + scale + + bias; + 
out[5] = static_cast((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (static_cast((w[3] & 0xc0) >> 6) + + static_cast((w[4] & 0x7) << 2)) * + scale + + bias; + out[7] = static_cast((w[4] & 0xf8) >> 3) * scale + bias; + } else if constexpr (bits == 6) { + w += offset * bytes_per_pack; + out[0] = static_cast(w[0] & 0x3f) * scale + bias; + out[1] = (static_cast((w[0] >> 6) & 0x03) + + static_cast((w[1] & 0x0f) << 2)) * + scale + + bias; + out[2] = (static_cast((w[1] >> 4) & 0x0f) + + static_cast((w[2] & 0x03) << 4)) * + scale + + bias; + out[3] = static_cast((w[2] >> 2) & 0x3f) * scale + bias; + } else { + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = scale * static_cast(d) + bias; + } + } +} + +} // namespace cu +namespace { + +inline array ensure_row_contiguous( + const array& x, + cu::CommandEncoder& enc, + const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + enc.add_temporary(x_copy); + return x_copy; + } else { + return x; + } +} + +} // namespace + +template +void dispatch_groups(int group_size, F&& f) { + switch (group_size) { + case 32: + f(std::integral_constant{}); + break; + case 64: + f(std::integral_constant{}); + break; + case 128: + f(std::integral_constant{}); + break; + } +} + +template +void dispatch_bits(int bits, F&& f) { + switch (bits) { + case 2: + f(std::integral_constant{}); + break; + case 3: + f(std::integral_constant{}); + break; + case 4: + f(std::integral_constant{}); + break; + case 5: + f(std::integral_constant{}); + break; + case 6: + f(std::integral_constant{}); + break; + case 8: + f(std::integral_constant{}); + break; + } +} + +void fast::AffineQuantize::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& w_pre = inputs[0]; + auto& out = outputs[0]; + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& d = cu::device(s.device); + auto& enc = d.get_command_encoder(s); + + auto w = ensure_row_contiguous(w_pre, enc, s); + enc.set_input_array(w); + if (dequantize_) { + auto scales = ensure_row_contiguous(inputs[1], enc, s); + auto biases = ensure_row_contiguous(inputs[2], enc, s); + enc.set_input_array(scales); + enc.set_input_array(biases); + enc.set_output_array(out); + } else { + auto& scales = outputs[1]; + auto& biases = outputs[2]; + scales.set_data(allocator::malloc(scales.nbytes())); + biases.set_data(allocator::malloc(biases.nbytes())); + enc.set_output_array(out); + enc.set_output_array(scales); + enc.set_output_array(biases); + } + + auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype(); + + // Treat uint32 as uint8 in kernel + int uint8_per_uint32 = 4; + int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8 + : bits_ == 6 ? 4 + : 8 / bits_; + int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE; + size_t size = + dequantize_ ? 
out.size() / packs_per_int : w.size() / per_thread; + + bool large = size > UINT_MAX; + auto grid_shape = w.shape(); + + if (dequantize_) { + grid_shape.back() *= uint8_per_uint32; + } else { + grid_shape.back() /= per_thread; + } + + dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) { + dispatch_groups(group_size_, [&](auto group_size) { + dispatch_bits(bits_, [&](auto bits) { + using DataType = cuda_type_t; + if (dequantize_) { + auto kernel = cu::affine_dequantize; + auto [num_blocks, block_dims] = + get_launch_args(kernel, size, grid_shape, w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + w.data(), + inputs[1].data(), + inputs[2].data(), + out.data(), + out.size()); + } else { + auto kernel = cu::affine_quantize; + auto [num_blocks, block_dims] = + get_launch_args(kernel, size, grid_shape, w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + w.data(), + out.data(), + outputs[1].data(), + outputs[2].data(), + w.size()); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu new file mode 100644 index 000000000..7221af356 --- /dev/null +++ b/mlx/backend/cuda/random.cu @@ -0,0 +1,194 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/primitives.h" + +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +__constant__ constexpr uint32_t rotations[2][4] = { + {13, 15, 26, 6}, + {17, 29, 16, 24}}; + +union rbits { + uint2 val; + uint8_t bytes[2][4]; +}; + +__device__ rbits threefry2x32_hash(uint2 key, uint2 count) { + uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; + + rbits v; + v.val.x = count.x + ks[0]; + v.val.y = count.y + ks[1]; + + for (int i = 0; i < 5; ++i) { + for (auto r : rotations[i % 2]) { + v.val.x += v.val.y; + v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); + v.val.y ^= v.val.x; + } + v.val.x += ks[(i + 1) % 3]; + v.val.y += ks[(i + 2) % 3] + i + 1; + } + + return v; +} + +__global__ void rbitsc( + const uint32_t* keys, + uint8_t* out, + dim3 grid_dims, + bool odd, + uint32_t bytes_per_key) { + auto grid = cg::this_grid(); + uint thread_index = grid.thread_rank(); + uint index_x = thread_index % grid_dims.x; + uint index_y = thread_index / grid_dims.x; + if (index_x >= grid_dims.x || index_y >= grid_dims.y) { + return; + } + + auto kidx = 2 * index_x; + auto key = uint2{keys[kidx], keys[kidx + 1]}; + auto half_size = grid_dims.y - odd; + out += index_x * bytes_per_key; + bool drop_last = odd && (index_y == half_size); + auto bits = threefry2x32_hash( + key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y}); + size_t idx = size_t(index_y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 
0 : size_t(index_y) + grid_dims.y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +__global__ void rbits( + const uint32_t* keys, + uint8_t* out, + dim3 grid_dims, + bool odd, + uint32_t bytes_per_key, + int32_t ndim, + const __grid_constant__ Shape key_shape, + const __grid_constant__ Strides key_strides) { + auto grid = cg::this_grid(); + uint thread_index = grid.thread_rank(); + uint index_x = thread_index % grid_dims.x; + uint index_y = thread_index / grid_dims.x; + if (index_x >= grid_dims.x || index_y >= grid_dims.y) { + return; + } + + auto kidx = 2 * index_x; + auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim); + auto k2_elem = + elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim); + auto key = uint2{keys[k1_elem], keys[k2_elem]}; + auto half_size = grid_dims.y - odd; + out += size_t(index_x) * bytes_per_key; + bool drop_last = odd && (index_y == half_size); + auto bits = threefry2x32_hash( + key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y}); + size_t idx = size_t(index_y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +} // namespace cu + +void RandomBits::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("RandomBits::eval_gpu"); + assert(inputs.size() == 1); + + // keys has shape (N1, ..., NK, 2) + // out has shape (N1, ..., NK, M1, M2, ...) + auto& keys = inputs[0]; + uint32_t num_keys = keys.size() / 2; + + uint32_t elems_per_key = out.size() / num_keys; + uint32_t bytes_per_key = out.itemsize() * elems_per_key; + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4; + uint32_t half_size = out_per_key / 2; + bool odd = out_per_key % 2; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(keys); + encoder.set_output_array(out); + dim3 grid_dims{num_keys, half_size + odd}; + int64_t total = grid_dims.x * grid_dims.y; + int32_t threads_y = 1; + while ((total / threads_y) >= (1U << 31)) { + threads_y *= 2; + } + int32_t threads_x = cuda::ceil_div(total, threads_y); + auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); + auto& stream = encoder.stream(); + if (keys.flags().row_contiguous) { + encoder.add_kernel_node( + cu::rbitsc, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key); + } else { + encoder.add_kernel_node( + cu::rbits, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key, + keys.ndim(), + const_param(keys.shape()), + const_param(keys.strides())); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu new file mode 100644 index 000000000..8350eebb7 --- /dev/null +++ b/mlx/backend/cuda/reduce.cu @@ -0,0 +1,76 @@ +// Copyright © 2025 Apple Inc. 
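+
+// Entry point for Reduce: pick an all-, row- or column-reduction strategy
+// from the ReductionPlan, copying the input to a contiguous temporary first
+// when the plan is a general reduce, the input is not contiguous, or a
+// non-reduced axis is broadcast (stride 0).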
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Reduce::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Reduce::eval_gpu"); + assert(inputs.size() == 1); + array in = inputs[0]; + + // Make sure no identity reductions trickle down here. + assert(!axes_.empty()); + assert(out.size() != in.size()); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + if (in.size() == 0) { + init_reduce(encoder, in, out, reduce_type_); + return; + } + + // Reduce. + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. + // + // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting + // like we do in Metal. When it comes to broadcasted reduction axes + // some can be ignored eg for min/max. + bool broadcasted = false; + for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) { + if (j < axes_.size() && axes_[j] == i) { + j++; + } else { + broadcasted = in.strides(i) == 0; + } + } + if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { + array in_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, in_copy, CopyType::General, s); + encoder.add_temporary(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + if (plan.type == ContiguousAllReduce) { + all_reduce(encoder, in, out, reduce_type_); + return; + } + + if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousStridedReduce || + plan.type == GeneralStridedReduce) { + col_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + throw std::runtime_error("No plan reached in reduce."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu new file mode 100644 index 000000000..166a11a79 --- /dev/null +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -0,0 +1,157 @@ +// Copyright © 2025 Apple Inc. 
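+
+// Full reduction of an array down to a single value. For large inputs this
+// runs in two stages: each block reduces its block_step-sized slice into one
+// partial stored in an intermediate array, and a second launch then reduces
+// those partials into the final output.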
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { + // TODO: Process multiple "rows" in each thread + constexpr int M = 1; + + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = cu::ReduceInit::value(); + ReduceOp op; + + T vals[N]; + U accs[M]; + accs[0] = init; + + size_t start = grid.block_rank() * block_step; + size_t end = start + block_step; + size_t check = min(end, size); + + size_t i = start; + for (; i + block.size() * N <= check; i += block.size() * N) { + cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); + for (int j = 0; j < N; j++) { + accs[0] = op(accs[0], cast_to(vals[j])); + } + } + + if (i < check) { + cub::LoadDirectBlocked( + block.thread_rank(), in + i, vals, check - i, cast_to(init)); + for (int i = 0; i < N; i++) { + accs[0] = op(accs[0], cast_to(vals[i])); + } + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[grid.block_rank()] = accs[0]; + } +} + +} // namespace cu + +void all_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + constexpr int N_READS = 8; + + out.set_data(allocator::malloc(out.nbytes())); + + auto get_args = [](size_t size, int N) { + int threads = std::min(512UL, (size + N - 1) / N); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int reductions_per_step = threads * N; + size_t steps_needed = + (size + reductions_per_step - 1) / reductions_per_step; + + int blocks; + if (steps_needed < 32) { + blocks = 1; + } else if (steps_needed < 128) { + blocks = 32; + } else if (steps_needed < 512) { + blocks = 128; + } else if (steps_needed < 1024) { + blocks = 512; + } else { + blocks = 1024; + } + + size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t block_step = steps_per_block * reductions_per_step; + + return std::make_tuple(blocks, threads, block_step); + }; + + int blocks, threads; + size_t block_step; + size_t insize = in.size(); + Dtype dt = in.dtype(); + + // Cub doesn't like const pointers for load (sigh). 
+ void* indata = const_cast(in.data()); + + // Large array so allocate an intermediate and accumulate there + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + encoder.set_input_array(in); + if (blocks > 1) { + array intermediate({blocks}, out.dtype(), nullptr, {}); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); + encoder.add_temporary(intermediate); + encoder.set_output_array(intermediate); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + intermediate.data(), + block_step, + insize); + }); + }); + + // Set the input for the next step and recalculate the blocks + indata = intermediate.data(); + dt = intermediate.dtype(); + insize = intermediate.size(); + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + encoder.set_input_array(intermediate); + } + + encoder.set_output_array(out); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + out.data(), + block_step, + insize); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu new file mode 100644 index 000000000..fec5ca76b --- /dev/null +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -0,0 +1,265 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). 
+ size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + using ShapeVector = decltype(plan.shape); + using StridesVector = decltype(plan.strides); + + ShapeVector shape_vec; + StridesVector strides_vec; + + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + ShapeVector sorted_shape; + StridesVector sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size(); + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4> +__global__ void +col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int threads_per_row = BN / N_READS; + + // Compute the indices for the tile + size_t tile_idx = grid.block_rank(); + size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); + size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); + + // Compute the indices for the thread within the tile + short thread_x = block.thread_rank() % threads_per_row; + short thread_y = block.thread_rank() / threads_per_row; + + // Move the input pointer + in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + + tile_x * BN; + + // Initialize the running totals + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data()); + size_t total = args.non_col_reductions * args.reduction_size; + if (tile_x * BN + BN <= args.reduction_stride) { + if (args.reduction_stride % N_READS == 0) { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } else { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + } else { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked( + thread_x, + in + 
loop.location(), + vals, + args.reduction_stride - tile_x * BN, + cast_to(ReduceInit::value())); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / threads_per_row; + static_assert(BM == 32 && n_outputs == N_READS); + __shared__ U shared_vals[BM * BN]; + short s_idx = thread_y * BN + thread_x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[s_idx + i] = totals[i]; + } + block.sync(); + s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + for (int i = 0; i < n_outputs; i++) { + totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); + } + + // Write result. + if (warp.thread_rank() == 0) { + cub::StoreDirectBlocked( + warp.meta_group_rank(), + out + tile_y * args.reduction_stride + tile_x * BN, + totals, + args.reduction_stride - tile_x * BN); + } +} + +} // namespace cu + +inline auto output_grid_for_col_reduce( + const array& out, + const cu::ColReduceArgs& args, + int bn) { + int gx, gy = 1; + size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); + size_t n_outer_blocks = out.size() / args.reduction_stride; + size_t n_blocks = n_outer_blocks * n_inner_blocks; + while (n_blocks / gy > INT32_MAX) { + gy *= 2; + } + gx = cuda::ceil_div(n_blocks, gy); + + return dim3(gx, gy, 1); +} + +void col_reduce_looped( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + cu::ColReduceArgs args) { + // Allocate data for the output using in's layout to access them as + // contiguously as possible. + allocate_same_layout(out, in, axes); + + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); + + constexpr int N_READS = 4; + constexpr int BM = 32; + constexpr int BN = 32; + dim3 grid = output_grid_for_col_reduce(out, args, BN); + int blocks = BM * BN / N_READS; + auto kernel = + cu::col_reduce_looped; + encoder.add_kernel_node( + kernel, grid, blocks, indata, out.data(), args); + }); + }); + }); +} + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + // Current col reduce options + // + // - col_reduce_looped + // + // It is a general strided reduce. Each threadblock computes the output for + // a subrow of the fast moving axis. For instance 32 elements. + // + // Notes: As in row reduce we opt to read as much in order as possible and + // leave transpositions as they are (contrary to our Metal backend). 
+ // + // Moreover we need different kernels for short rows and tuning + + // Make the args struct to help route to the best kernel + cu::ColReduceArgs args(in, plan, axes); + + // Fallback col reduce + col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu new file mode 100644 index 000000000..649d80190 --- /dev/null +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -0,0 +1,49 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void init_reduce(U* out, size_t size) { + auto index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = ReduceInit::value(); + } +} + +} // namespace cu + +void init_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + // Allocate if needed + if (out.data_shared_ptr() == nullptr) { + out.set_data(allocator::malloc(out.nbytes())); + } + + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::init_reduce; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); + grid.x = (grid.x + 1023) / 1024; + encoder.add_kernel_node(kernel, grid, block, out.data(), out.size()); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh new file mode 100644 index 000000000..02e495594 --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -0,0 +1,71 @@ +// Copyright © 2025 Apple Inc. 
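For reference, init_reduce above encodes the semantics of an empty reduction: every output element is set to the operation's identity, taken from the ReduceInit traits that appear later in this diff.

// Identity written by init_reduce for each op (from ReduceInit):
//   Sum -> 0        Prod -> 1
//   Min -> Limits<T>::max()    Max -> Limits<T>::min()
//   And -> true     Or   -> false
// So, for example, summing a (0, 3) float array over axis 0 yields {0, 0, 0}.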
+ +#include + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +template +void dispatch_reduce_ndim(int ndim, F&& f) { + if (ndim == 1) { + f(std::integral_constant{}); + } else if (ndim == 2) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +template +void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { + if (reduce_type == Reduce::ReduceType::And) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Or) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Sum) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Prod) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Max) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Min) { + f(type_identity{}); + } else { + throw std::invalid_argument("Unknown reduce type."); + } +} + +void all_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type); + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void init_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh new file mode 100644 index 000000000..7f8cad0c4 --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -0,0 +1,211 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/atomic_ops.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/device/utils.cuh" +#include "mlx/backend/cuda/reduce/reduce_utils.cuh" + +namespace mlx::core::cu { + +// Reduce ops. 
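// Each op provides two pieces used by the reduction kernels:
//   - operator()(a, b): the plain combine step applied in registers and warps.
//   - atomic_update(x, y): used when independent blocks fold results into the
//     same output slot. Most ops go through the generic CAS loop in
//     atomic_reduce<Op, T> (see reduce_utils.cuh); Sum additionally routes
//     float, int and bfloat16 to atomic_add.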
+struct And { + __device__ __forceinline__ bool operator()(bool a, bool b) { + return a && b; + } + + __device__ void atomic_update(bool* x, bool y) { + atomic_reduce(x, y); + } +}; + +struct Or { + __device__ __forceinline__ bool operator()(bool a, bool b) { + return a || b; + } + + __device__ void atomic_update(bool* x, bool y) { + atomic_reduce(x, y); + } +}; + +struct Sum { + template + __device__ __forceinline__ T operator()(T a, T b) { + return a + b; + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } + + __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) { + atomic_add(x, y); + } + + __device__ void atomic_update(int* x, int y) { + atomic_add(x, y); + } + + __device__ void atomic_update(float* x, float y) { + atomic_add(x, y); + } +}; + +struct Prod { + template + __device__ __forceinline__ T operator()(T a, T b) { + return a * b; + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + +struct Min { + template + __device__ __forceinline__ T operator()(T a, T b) { + if constexpr (is_complex_v) { + if (isnan(a.real()) || isnan(a.imag())) { + return a; + } + if (isnan(b.real()) || isnan(b.imag())) { + return b; + } + } else if constexpr (!cuda::std::is_integral_v) { + if (isnan(a) || isnan(b)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + } + return a < b ? a : b; + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + +struct Max { + template + __device__ __forceinline__ T operator()(T a, T b) { + if constexpr (is_complex_v) { + if (isnan(a.real()) || isnan(a.imag())) { + return a; + } + if (isnan(b.real()) || isnan(b.imag())) { + return b; + } + } else if constexpr (!cuda::std::is_integral_v) { + if (isnan(a) || isnan(b)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + } + return a > b ? a : b; + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + +// Traits to get the result type of reduce op. +template +struct ReduceResult; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = T; +}; + +template +struct ReduceResult { + using type = T; +}; + +// Traits to get the init value of reduce op. 
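// For reference, the traits above and below evaluate to, e.g.:
//   ReduceResult<Sum, int8_t>::type    == int32_t   (small ints widen to avoid overflow)
//   ReduceResult<And, float>::type     == bool
//   ReduceInit<Sum, float>::value()    == 0.0f
//   ReduceInit<Prod, int16_t>::value() == int32_t(1)
//   ReduceInit<Min, float>::value()    == Limits<float>::max()
//   ReduceInit<Max, float>::value()    == Limits<float>::min()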
+template +struct ReduceInit; + +template +struct ReduceInit { + static constexpr __host__ __device__ bool value() { + return true; + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ bool value() { + return false; + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ auto value() { + if constexpr (is_complex_v) { + return T{0, 0}; + } else { + return cast_to::type>(0); + } + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ auto value() { + if constexpr (is_complex_v) { + return T{1, 0}; + } else { + return cast_to::type>(1); + } + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::max(); + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::min(); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh new file mode 100644 index 000000000..d993bacbb --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -0,0 +1,143 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device/utils.cuh" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +struct uint_by_size; +template <> +struct uint_by_size<2> { + using type = uint16_t; +}; +template <> +struct uint_by_size<4> { + using type = uint32_t; +}; +template <> +struct uint_by_size<8> { + using type = unsigned long long int; +}; + +template +__device__ void atomic_reduce(T* x, T y) { + if constexpr (sizeof(T) == 1) { + using U = uint16_t; + U* x_int = (U*)((char*)x - ((size_t)x % 2)); + int shift = ((char*)x - (char*)x_int) * 8; + int mask = 0xff << shift; + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(static_cast((old_val >> shift) & 0xff), y); + new_val = (old_val & ~mask) | (result << shift); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } else { + using U = typename uint_by_size::type; + U* x_int = (U*)(x); + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(*((T*)&old_val), y); + new_val = *((U*)&result); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } +} + +template +inline __device__ void +block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { + // First reduce in the current warp + for (int i = 0; i < N; i++) { + vals[i] = cg::reduce(warp, vals[i], op); + } + + // Reduce across warps + if (warp.meta_group_size() > 1) { + if (warp.thread_rank() == 0) { + for (int i = 0; i < N; i++) { + smem[warp.meta_group_rank() * N + i] = vals[i]; + } + } + block.sync(); + if (warp.thread_rank() < warp.meta_group_size()) { + for (int i = 0; i < N; i++) { + vals[i] = smem[warp.thread_rank() * N + i]; + } + } else { + for (int i = 0; i < N; i++) { + vals[i] = init; + } + } + for (int i = 0; i < N; i++) { + vals[i] = cg::reduce(warp, vals[i], op); + } + } +} + +} // namespace cu + +inline void allocate_same_layout( + array& out, + const array& in, + const std::vector& axes) { + if (in.flags().row_contiguous) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + + if (out.ndim() < in.ndim()) { + throw std::runtime_error( + "Reduction without keepdims only supported for row-contiguous inputs"); + } + + // Calculate the transpositions applied to in in order to apply them to out. 
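  // Worked example (illustrative): if `in` has shape (2, 3, 4) with strides
  // (1, 2, 6) -- a transposed row-contiguous buffer -- and axis 1 is reduced
  // with keepdims, `out` has shape (2, 1, 4). Sorting axes by decreasing input
  // stride gives axis_order = {2, 1, 0}; building row-contiguous strides for
  // the permuted out shape (4, 1, 2) gives (2, 2, 1), and scattering those
  // back through axis_order yields final strides (1, 2, 2). The output thus
  // mirrors the input's memory order, so the kernels can write it in the same
  // order they read `in`.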
+ std::vector axis_order(in.ndim()); + std::iota(axis_order.begin(), axis_order.end(), 0); + std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) { + return in.strides(left) > in.strides(right); + }); + + // Transpose the shape and calculate the strides + Shape out_shape(in.ndim()); + Strides out_strides(in.ndim(), 1); + for (int i = 0; i < in.ndim(); i++) { + out_shape[i] = out.shape(axis_order[i]); + } + for (int i = in.ndim() - 2; i >= 0; i--) { + out_strides[i] = out_shape[i + 1] * out_strides[i + 1]; + } + + // Reverse the axis order to get the final strides + Strides final_strides(in.ndim()); + for (int i = 0; i < in.ndim(); i++) { + final_strides[axis_order[i]] = out_strides[i]; + } + + // Calculate the resulting contiguity and do the memory allocation + auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides); + auto fl = in.flags(); + fl.row_contiguous = rc; + fl.col_contiguous = cc; + fl.contiguous = true; + out.set_data( + allocator::malloc(out.nbytes()), + data_size, + final_strides, + fl, + allocator::free); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu new file mode 100644 index 000000000..61838ddd3 --- /dev/null +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -0,0 +1,368 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct RowReduceArgs { + // The size of the row being reduced, i.e. the size of last dimension. + int row_size; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes excluding last dimension. + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of rows we are reducing. Namely prod(reduce_shape). 
+ size_t non_row_reductions; + + RowReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + row_size = plan.shape.back(); + + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size() - 1; + + non_row_reductions = 1; + for (int i = 0; i < reduce_ndim; i++) { + non_row_reductions *= reduce_shape[i]; + } + } + + // Convert shape and strides as if in was contiguous + void sort_access_pattern(const array& in, const std::vector& axes) { + auto shape_vec = in.shape(); + auto strides_vec = in.strides(); + std::tie(shape_vec, strides_vec) = + shapes_without_reduction_axes(shape_vec, strides_vec, axes); + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + decltype(shape_vec) sorted_shape; + decltype(strides_vec) sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + } +}; + +template +__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = cu::ReduceInit::value(); + ReduceOp op; + + T vals[M][N]; + U accs[M]; + for (int i = 0; i < M; i++) { + accs[i] = init; + } + + const size_t start_row = + min(n_rows - M, static_cast(grid.block_rank() * M)); + const size_t full_blocks = size / (block.size() * N); + const size_t final_offset = full_blocks * (block.size() * N); + in += start_row * size; + out += start_row; + + if (size % N == 0) { + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlockedVectorized( + block.thread_rank(), + in + k * size + r * (block.size() * N), + vals[k]); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], cast_to(vals[k][j])); + } + } + } + } else { + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlocked( + block.thread_rank(), + in + k * size + r * (block.size() * N), + vals[k]); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], cast_to(vals[k][j])); + } + } + } + } + + if (final_offset < size) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlocked( + block.thread_rank(), + in + k * size + final_offset, + vals[k], + size, + cast_to(init)); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], cast_to(vals[k][j])); + } + } + } + + __shared__ U shared_accumulators[32 * M]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + if (grid.block_rank() * M + M <= n_rows) { + for (int i = 0; i < M; i++) { + out[i] = accs[i]; + } + } else { + short offset = grid.block_rank() * M + M - n_rows; + for (int i = offset; i < M; i++) { + out[i] = accs[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BLOCK_DIM, + int N_READS = 
4> +__global__ void row_reduce_looped( + T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + size_t out_idx = grid.block_rank(); + + Op op; + + U total[1]; + U init = ReduceInit::value(); + total[0] = init; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS); + size_t final_offset = full_blocks * BLOCK_DIM * N_READS; + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + for (size_t r = 0; r < full_blocks; r++) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized( + block.thread_rank(), + in + loop.location() + r * BLOCK_DIM * N_READS, + vals); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } + } + if (final_offset < args.row_size) { + T vals[N_READS]; + cub::LoadDirectBlocked( + block.thread_rank(), + in + loop.location() + final_offset, + vals, + args.row_size - final_offset, + cast_to(init)); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } + } + // TODO: Maybe block.sync() here? + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, total, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[out_idx] = total[0]; + } +} + +} // namespace cu + +void row_reduce_simple( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + constexpr int N_READS = 8; + + // Allocate data for the output using in's layout to avoid elem_to_loc in the + // kernel. + allocate_same_layout(out, in, axes); + + // TODO: If out.size() < 1024 which will be a common case then write this in + // 2 passes. Something like 32 * out.size() and then do a warp reduce. + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); + + // Calculate the grid and block dims + size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); + + // Pick the kernel + auto kernel = cu::row_reduce_simple; + if (grid.x >= 1024) { + grid.x = (grid.x + 1) / 2; + kernel = cu::row_reduce_simple; + } + + int size = plan.shape.back(); + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), size); + }); + }); +} + +void row_reduce_looped( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + cu::RowReduceArgs args) { + constexpr int N_READS = 8; + + // Allocate data for the output using in's layout to access them as + // contiguously as possible. 
+ allocate_same_layout(out, in, axes); + + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); + + // Calculate the grid and block dims + args.sort_access_pattern(in, axes); + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + size_t reductions = (args.row_size + N_READS - 1) / N_READS; + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); + + // Pick the kernel + auto kernel = cu::row_reduce_looped; + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + dispatch_block_dim(threads, [&](auto threads_constant) { + kernel = cu::row_reduce_looped< + T, + U, + OP, + reduce_ndim.value, + threads_constant.value, + N_READS>; + block.x = threads_constant.value; + }); + }); + + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), args); + }); + }); +} + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + // Current row reduction options + // + // - row_reduce_simple + // + // That means that we are simply reducing across the fastest moving axis. + // We are reducing 1 or 2 rows per threadblock depending on the size of + // output. + // + // - row_reduce_looped + // + // It is a general row reduction. We are computing 1 output per + // threadblock. We read the fastest moving axis vectorized and loop over + // the rest of the axes. + // + // Notes: We opt to read as much in order as possible and leave + // transpositions as they are (contrary to our Metal backend). + + // Simple row reduce means that we have 1 axis that we are reducing over and + // it has stride 1. + if (plan.shape.size() == 1) { + row_reduce_simple(encoder, in, out, reduce_type, axes, plan); + return; + } + + // Make the args struct to help route to the best kernel + cu::RowReduceArgs args(in, plan, axes); + + // Fallback row reduce + row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args)); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu new file mode 100644 index 000000000..964bd7d98 --- /dev/null +++ b/mlx/backend/cuda/rms_norm.cu @@ -0,0 +1,354 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/strided_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +inline __device__ float2 plus_f2(const float2& a, const float2& b) { + return {a.x + b.x, a.y + b.y}; +} + +// Similar to cub::BlockReduce, but result is broadcasted to every thread. 
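// How the broadcast works: every warp reduces its own lanes with cg::reduce,
// lane 0 of each warp parks its partial result in `temp`, and then each warp
// re-reduces the parked partials, so all threads of the block end up holding
// the same block-wide result without a separate broadcast step.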
+template +struct BlockBroadcastReduce { + static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); + static_assert(BLOCK_DIM % WARP_SIZE == 0); + using TempStorage = T[BLOCK_DIM / WARP_SIZE]; + + cg::thread_block& block; + TempStorage& temp; + + template + __device__ T Reduce(const T& input, const Op& op, const T& init_value) { + auto warp = cg::tiled_partition(block); + T x = cg::reduce(warp, input, op); + if (warp.thread_rank() == 0) { + temp[warp.meta_group_rank()] = x; + } + block.sync(); + x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()] + : init_value; + return cg::reduce(warp, x, op); + } + + __device__ T Sum(const T& input) { + return Reduce(input, cg::plus{}, T{}); + } +}; + +template +__global__ void rms_norm( + const T* x, + const T* w, + T* out, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceT = BlockBroadcastReduce; + __shared__ typename BlockReduceT::TempStorage temp; + + x += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + // Normalizer. + float normalizer = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to(0)); + for (int i = 0; i < N_READS; ++i) { + float t = static_cast(xn[i]); + normalizer += t * t; + } + } + normalizer = BlockReduceT{block, temp}.Sum(normalizer); + normalizer = rsqrt(normalizer / axis_size + eps); + + // Outputs. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float norm = static_cast(xn[i]) * normalizer; + xn[i] = wn[i] * static_cast(norm); + } + cub::StoreDirectBlocked(index, out, xn, axis_size); + } +} + +template +__global__ void rms_norm_vjp( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceF = BlockBroadcastReduce; + using BlockReduceF2 = BlockBroadcastReduce; + __shared__ union { + typename BlockReduceF::TempStorage f; + typename BlockReduceF2::TempStorage f2; + } temp; + + x += grid.block_rank() * axis_size; + g += grid.block_rank() * axis_size; + gx += grid.block_rank() * axis_size; + gw += grid.block_rank() * axis_size; + + // Normalizer. + float2 factors = {}; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T xn[N_READS]; + T wn[N_READS] = {}; + T gn[N_READS] = {}; + auto index = r * BLOCK_DIM + block.thread_rank(); + cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to(0)); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float t = static_cast(xn[i]); + float wi = wn[i]; + float gi = gn[i]; + float wg = wi * gi; + factors = plus_f2(factors, {wg * t, t * t}); + } + } + factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {}); + float meangwx = factors.x / axis_size; + float normalizer = rsqrt(factors.y / axis_size + eps); + float normalizer3 = normalizer * normalizer * normalizer; + + // Outputs. 
+ for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T gn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float xi = xn[i]; + float wi = wn[i]; + float gi = gn[i]; + xn[i] = static_cast(normalizer * wi * gi - xi * meangwx * normalizer3); + if constexpr (HAS_W) { + wn[i] = static_cast(gi * xi * normalizer); + } + } + cub::StoreDirectBlocked(index, gx, xn, axis_size); + if constexpr (HAS_W) { + cub::StoreDirectBlocked(index, gw, wn, axis_size); + } + } +} + +} // namespace cu + +namespace fast { + +bool RMSNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +// TODO: There are duplicate code with backend/metal/normalization.cpp +void RMSNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("RMSNorm::eval_gpu"); + auto& s = stream(); + auto& out = outputs[0]; + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + const array x = set_output(inputs[0]); + const array& w = inputs[1]; + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_output_array(out); + dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::rms_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); + }); + }); +} + +void RMSNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("RMSNormVJP::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + // Ensure row contiguity. We could relax this step by checking that the array + // is contiguous (no broadcasts or holes) and that the input strides are the + // same as the cotangent strides but for now this is simpler. 
+ auto check_input = [&s](const array& x, bool& copied) { + if (x.flags().row_contiguous) { + copied = false; + return x; + } + copied = true; + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + return x_copy; + }; + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[2].is_donatable(); + bool copied; + auto x = check_input(inputs[0], copied); + donate_x |= copied; + const array& w = inputs[1]; + bool g_copied; + auto g = check_input(inputs[2], g_copied); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + + // Check whether we had a weight. + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs. + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(allocator::malloc(gx.nbytes())); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w and allocate the output + // gradient accumulators. + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + if (has_w) { + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + encoder.add_temporary(gw_temp); + } + } + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + auto kernel = cu::rms_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); + + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu new file mode 100644 index 000000000..517cddfe0 --- /dev/null +++ b/mlx/backend/cuda/rope.cu @@ -0,0 +1,401 @@ +// Copyright © 2025 Apple Inc. 
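The weight-gradient step at the end of RMSNormVJP::eval_gpu above leans on the reduction backend rather than a bespoke kernel; in shape terms (an illustrative restatement, not new code in the patch):

// gw_temp : (n_rows, axis_size)  one dL/dw contribution per normalized row
// gw      : (axis_size,)         the final weight gradient
// col_reduce with ReductionPlan(ContiguousStridedReduce, {n_rows}, {axis_size})
// computes gw[j] = sum_i gw_temp[i][j], i.e. a Sum reduction over axis 0.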
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +namespace cu { + +template +__device__ void rope_single_impl( + const T* in, + T* out, + int32_t offset, + float inv_freq, + float scale, + int64_t stride, + uint2 pos, + uint2 dims) { + float L = scale * static_cast(offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = cos(theta); + float sintheta = sin(theta); + + // Compute the input and output indices + uint index_1, index_2; + if (traditional) { + index_1 = 2 * pos.x + pos.y * stride; + index_2 = index_1 + 1; + } else { + index_1 = pos.x + pos.y * stride; + index_2 = index_1 + dims.x; + } + + // Read and write the output + float x1 = static_cast(in[index_1]); + float x2 = static_cast(in[index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[index_1] = static_cast(rx1); + out[index_2] = static_cast(rx2); +} + +template +__global__ void rope_single( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + int64_t stride, + uint2 dims) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2(-d * base); + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +template +__global__ void rope_single_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + int64_t stride, + uint2 dims, + int64_t freq_stride) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float inv_freq = 1.0 / freqs[freq_stride * pos.x]; + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +template +__device__ void rope_impl( + const T* in, + T* out, + int offset, + float inv_freq, + float scale, + const cuda::std::array strides, + const cuda::std::array out_strides, + int64_t n_batch, + uint3 pos, + uint3 dims) { + float L = scale * static_cast(pos.y + offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = cos(theta); + float sintheta = sin(theta); + + // Compute the input and output indices + size_t in_index_1, in_index_2; + size_t out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + 1; + in_index_1 = + 2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + dims.x * out_strides[2]; + in_index_1 = + pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + dims.x * strides[2]; + } + for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + 
rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); + in_index_1 += strides[0]; + in_index_2 += strides[0]; + out_index_1 += out_strides[0]; + out_index_2 += out_strides[0]; + } +} + +template +__global__ void rope( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + const __grid_constant__ cuda::std::array strides, + const __grid_constant__ cuda::std::array out_strides, + int64_t n_batch, + uint3 dims) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2(-d * base); + rope_impl( + in, + out, + *offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + dims); +} + +template +__global__ void rope_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + float base, + const __grid_constant__ cuda::std::array strides, + const __grid_constant__ cuda::std::array out_strides, + int64_t n_batch, + uint3 dims, + int64_t freq_stride) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float inv_freq = 1.0 / freqs[freq_stride * pos.x]; + rope_impl( + in, + out, + *offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + dims); +} + +} // namespace cu + +namespace fast { + +bool RoPE::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +void RoPE::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("RoPE::eval_gpu"); + + auto& s = stream(); + auto& in = inputs[0]; + auto& offset = inputs[1]; + auto& out = outputs[0]; + + if (in.ndim() < 3) { + throw std::runtime_error("[RoPE] Input must have at least 3 dimensions"); + } + + cuda::std::array strides; + cuda::std::array out_strides; + bool donated = false; + int ndim = in.ndim(); + int dispatch_ndim = in.ndim(); + while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { + dispatch_ndim--; + } + size_t mat_size = in.shape(-2) * in.shape(-1); + + // We apply rope to less that the whole vector so copy to output and then + // apply in-place. + if (dims_ < in.shape(-1)) { + donated = true; + auto ctype = + (in.flags().row_contiguous) ? 
CopyType::Vector : CopyType::General; + copy_gpu(in, out, ctype, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + + // Either copy or apply in-place + else if (in.flags().row_contiguous) { + if (in.is_donatable()) { + donated = true; + out.copy_shared_buffer(in); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + strides[0] = mat_size; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else if (dispatch_ndim == 3) { + // Handle non-contiguous 3D inputs + out.set_data(allocator::malloc(out.nbytes())); + strides[0] = in.strides()[ndim - 3]; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else { + // Copy non-contiguous > 3D inputs into the output and treat + // input as donated + donated = true; + copy_gpu(in, out, CopyType::General, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + out_strides[0] = mat_size; + out_strides[1] = out.strides()[ndim - 2]; + out_strides[2] = out.strides()[ndim - 1]; + + // Some flags to help us dispatch below + bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); + bool with_freqs = inputs.size() == 3; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(donated ? out : in); + encoder.set_input_array(offset); + if (with_freqs) { + encoder.set_input_array(inputs[2]); + } + encoder.set_output_array(out); + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + dispatch_bool(traditional_, [&](auto traditional) { + dispatch_bool(forward_, [&](auto forward) { + using DataType = cuda_type_t; + if (single && !with_freqs) { + auto kernel = + cu::rope_single; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = + cu::rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = + cu::rope_freqs; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims, + inputs[2].strides(0)); + } else { + auto kernel = cu::rope; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? 
out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims); + } + }); + }); + }); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu new file mode 100644 index 000000000..7a26ee161 --- /dev/null +++ b/mlx/backend/cuda/scan.cu @@ -0,0 +1,467 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +struct ScanResult { + using type = T; +}; + +template <> +struct ScanResult { + using type = int32_t; +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::min(); + } +}; + +template +inline __device__ void +load_values(int index, const T* in, U (&values)[N_READS], int size, U init) { + int remaining = size - index * N_READS; + if constexpr (reverse) { + in += remaining - N_READS; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + values[N_READS - i - 1] = + (N_READS - i - 1 < remaining) ? cast_to(in[i]) : init; + } + } else { + for (int i = 0; i < N_READS; ++i) { + values[N_READS - i - 1] = cast_to(in[i]); + } + } + } else { + in += index * N_READS; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + values[i] = (i < remaining) ? cast_to(in[i]) : init; + } + } else { + for (int i = 0; i < N_READS; ++i) { + values[i] = cast_to(in[i]); + } + } + } +} + +template +inline __device__ void +store_values(int index, T* out, T (&values)[N_READS], int size) { + int start = index * N_READS + offset; + int remaining = size - start; + if constexpr (reverse) { + out += remaining - N_READS; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + if (N_READS - i - 1 < remaining) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + for (int i = 0; i < N_READS; ++i) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + out += start; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + if (i < remaining) { + out[i] = values[i]; + } + } + } else { + for (int i = 0; i < N_READS; ++i) { + out[i] = values[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + in += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + __shared__ U warp_sums[WARP_SIZE]; + + Op op; + U init = ReduceInit::value(); + U prefix = init; + + // Scan per block. + for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) { + int32_t index = r * block.size() + block.thread_rank(); + U values[N_READS]; + load_values(index, in, values, axis_size, init); + + // Compute an inclusive scan per thread. + for (int i = 1; i < N_READS; ++i) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums. 
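    // (After the per-thread inclusive scan, values[N_READS - 1] is this
    // thread's tile total. An exclusive scan of those totals across the warp
    // gives each thread the sum of everything to its left in the warp; the
    // warp totals are scanned the same way below, and `prefix` carries the
    // total of all previously processed tiles. Each output therefore combines
    // the block prefix, the warp prefix, the thread prefix and the
    // thread-local scan value.)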
+ U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op); + if (warp.thread_rank() == 0) { + prev_thread_sum = init; + } + + // Write wrap's sum to shared memory. + if (warp.thread_rank() == WARP_SIZE - 1) { + warp_sums[warp.meta_group_rank()] = + op(prev_thread_sum, values[N_READS - 1]); + } + block.sync(); + + // Compute exclusive scan of warp sums. + if (warp.meta_group_rank() == 0) { + U prev_warp_sum = + cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op); + if (warp.thread_rank() == 0) { + prev_warp_sum = init; + } + warp_sums[warp.thread_rank()] = prev_warp_sum; + } + block.sync(); + + // Compute the output. + for (int i = 0; i < N_READS; ++i) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], warp_sums[warp.meta_group_rank()]); + values[i] = op(values[i], prev_thread_sum); + } + + // Write the values. + if (inclusive) { + store_values(index, out, values, axis_size); + } else { + store_values(index, out, values, axis_size); + if (reverse) { + if (block.thread_rank() == 0 && index == 0) { + out[axis_size - 1] = init; + } + } else { + if (block.thread_rank() == 0 && index == 0) { + out[0] = init; + } + } + } + block.sync(); + + // Share the prefix. + if ((warp.meta_group_rank() == warp.meta_group_size() - 1) && + (warp.thread_rank() == WARP_SIZE - 1)) { + warp_sums[0] = values[N_READS - 1]; + } + block.sync(); + prefix = warp_sums[0]; + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + int BM, + int BN, + bool inclusive, + bool reverse> +__global__ void strided_scan( + const T* in, + U* out, + int32_t axis_size, + int64_t stride, + int64_t stride_blocks) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U); + constexpr int n_warps = BN / N_READS; + constexpr int n_scans = BN / n_warps; + + __shared__ U read_buffer[BM * BN_pad]; + + Op op; + U init = ReduceInit::value(); + U values[n_scans]; + U prefix[n_scans]; + for (int i = 0; i < n_scans; ++i) { + prefix[i] = init; + } + + // Compute offsets. + int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride; + int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN; + uint read_offset_y = (block.thread_rank() * N_READS) / BN; + uint read_offset_x = (block.thread_rank() * N_READS) % BN; + uint scan_offset_y = warp.thread_rank(); + uint scan_offset_x = warp.meta_group_rank() * n_scans; + + uint stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + global_index_x + read_offset_x; + U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x; + U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint j = 0; j < axis_size; j += BM) { + // Calculate the indices for the current thread. + uint index_y = j + read_offset_y; + uint check_index_y = index_y; + if (reverse) { + index_y = axis_size - 1 - index_y; + } + + // Read in SM. + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; ++i) { + read_into[i] = in[index_y * stride + i]; + } + } else { + for (int i = 0; i < N_READS; ++i) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = in[index_y * stride + i]; + } else { + read_into[i] = init; + } + } + } + block.sync(); + + // Read strided into registers. + for (int i = 0; i < n_scans; ++i) { + values[i] = read_from[i]; + } + + // Perform the scan. 
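      // Each warp owns n_scans adjacent columns of the BN-wide tile; after the
      // shared-memory staging above, the BM rows of those columns sit in the
      // warp's lanes, so cg::inclusive_scan scans down a column, prefix[i]
      // folds in the totals of earlier row-tiles, and the last lane's value is
      // broadcast with warp.shfl to seed the next tile's prefix.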
+ for (int i = 0; i < n_scans; ++i) { + values[i] = cg::inclusive_scan(warp, values[i], op); + values[i] = op(values[i], prefix[i]); + prefix[i] = warp.shfl(values[i], WARP_SIZE - 1); + } + + // Write to SM. + for (int i = 0; i < n_scans; ++i) { + read_from[i] = values[i]; + } + block.sync(); + + // Write to device memory. + if (!inclusive) { + if (check_index_y == 0) { + if ((read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; ++i) { + out[index_y * stride + i] = init; + } + } else { + for (int i = 0; i < N_READS; ++i) { + if ((read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = init; + } + } + } + } + if (reverse) { + index_y -= 1; + check_index_y += 1; + } else { + index_y += 1; + check_index_y += 1; + } + } + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; ++i) { + out[index_y * stride + i] = read_into[i]; + } + } else { + for (int i = 0; i < N_READS; ++i) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = read_into[i]; + } + } + } + } +} + +} // namespace cu + +template +void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) { + if (scan_op == Scan::ReduceType::Max) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Min) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Sum) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Prod) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::LogAddExp) { + f(type_identity{}); + } else { + throw std::invalid_argument("Unknown reduce type."); + } +} + +template +const char* op_to_string() { + if (cuda::std::is_same_v) { + return "Max"; + } else if (cuda::std::is_same_v) { + return "Min"; + } else if (cuda::std::is_same_v) { + return "Sum"; + } else if (cuda::std::is_same_v) { + return "Prod"; + } else if (cuda::std::is_same_v) { + return "LogAddExp"; + } else { + throw std::invalid_argument("Unknown op."); + } +} + +template +constexpr bool supports_scan_op() { + if constexpr (cuda::std::is_same_v) { + return is_inexact_v; + } else { + return true; + } +} + +void Scan::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Scan::eval_gpu"); + assert(inputs.size() == 1); + auto in = inputs[0]; + auto& s = stream(); + + if (in.flags().contiguous && in.strides()[axis_] != 0) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.copy_shared_buffer(in); + } else { + out.set_data( + allocator::malloc(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + array arr_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, arr_copy, CopyType::General, s); + in = std::move(arr_copy); + out.copy_shared_buffer(in); + } + + constexpr int N_READS = 4; + int32_t axis_size = in.shape(axis_); + bool contiguous = in.strides()[axis_] == 1; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using T = cuda_type_t; + dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) { + using Op = MLX_GET_TYPE(scan_op_tag); + if constexpr (supports_scan_op) { + using U = typename cu::ScanResult::type; + dispatch_bool(inclusive_, [&](auto inclusive) { + dispatch_bool(reverse_, [&](auto reverse) { + if (contiguous) { + auto kernel = cu::contiguous_scan< + T, + U, + Op, + N_READS, + inclusive.value, + reverse.value>; + int block_dim = cuda::ceil_div(axis_size, N_READS); + 
block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE; + block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE); + encoder.add_kernel_node( + kernel, + in.data_size() / axis_size, + block_dim, + in.data(), + out.data(), + axis_size); + } else { + constexpr int BM = WARP_SIZE; + constexpr int BN = WARP_SIZE; + auto kernel = cu::strided_scan< + T, + U, + Op, + N_READS, + BM, + BN, + inclusive.value, + reverse.value>; + int64_t stride = in.strides()[axis_]; + int64_t stride_blocks = cuda::ceil_div(stride, BN); + dim3 num_blocks = get_2d_grid_dims( + in.shape(), in.strides(), axis_size * stride); + if (num_blocks.x * stride_blocks <= UINT32_MAX) { + num_blocks.x *= stride_blocks; + } else { + num_blocks.y *= stride_blocks; + } + int block_dim = (BN / N_READS) * WARP_SIZE; + encoder.add_kernel_node( + kernel, + num_blocks, + block_dim, + in.data(), + out.data(), + axis_size, + stride, + stride_blocks); + } + }); + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do scan op {} on inputs of {} with result of {}.", + op_to_string(), + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp new file mode 100644 index 000000000..af67fbbdd --- /dev/null +++ b/mlx/backend/cuda/slicing.cpp @@ -0,0 +1,41 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +#include + +namespace mlx::core { + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s) { + std::vector sizes; + sizes.push_back(0); + for (auto& p : inputs) { + sizes.push_back(p.shape(axis)); + } + std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); + + out.set_data(allocator::malloc(out.nbytes())); + + auto strides = out.strides(); + auto flags = out.flags(); + flags.row_contiguous = false; + flags.col_contiguous = false; + flags.contiguous = false; + // TODO: Handle concurrent outputs: + // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816 + for (int i = 0; i < inputs.size(); i++) { + array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); + size_t data_offset = strides[axis] * sizes[i]; + out_slice.copy_shared_buffer( + out, strides, flags, out_slice.size(), data_offset); + copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu new file mode 100644 index 000000000..56f67d7f3 --- /dev/null +++ b/mlx/backend/cuda/softmax.cu @@ -0,0 +1,163 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +inline __device__ T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). 
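+  // __expf is the fast, reduced-precision CUDA intrinsic; the accuracy loss
+  // is acceptable for the reason described above.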
+ return __expf(x); +} + +template +__global__ void softmax(const T* in, T* out, int axis_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + in += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + cg::greater max_op; + cg::plus plus_op; + + // Thread reduce. + AccT prevmax; + AccT maxval = Limits::finite_min(); + AccT normalizer = cast_to(0); + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + AccT vals[N_READS]; + cub::LoadDirectBlocked( + r * BLOCK_DIM + block.thread_rank(), + make_cast_iterator(in), + vals, + axis_size, + Limits::min()); + prevmax = maxval; + maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); + // Online normalizer calculation for softmax: + // https://github.com/NVIDIA/online-softmax + normalizer = normalizer * softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer = normalizer + softmax_exp(vals[i] - maxval); + } + } + + // First warp reduce. + prevmax = maxval; + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = cg::reduce(warp, normalizer, plus_op); + + __shared__ AccT local_max[WARP_SIZE]; + __shared__ AccT local_normalizer[WARP_SIZE]; + + // Write to shared memory and do second warp reduce. + prevmax = maxval; + if (warp.thread_rank() == 0) { + local_max[warp.meta_group_rank()] = maxval; + } + block.sync(); + maxval = warp.thread_rank() < warp.meta_group_size() + ? local_max[warp.thread_rank()] + : Limits::min(); + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + if (warp.thread_rank() == 0) { + local_normalizer[warp.meta_group_rank()] = normalizer; + } + block.sync(); + normalizer = warp.thread_rank() < warp.meta_group_size() + ? local_normalizer[warp.thread_rank()] + : AccT{}; + normalizer = cg::reduce(warp, normalizer, plus_op); + normalizer = 1 / normalizer; + + // Write output. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T vals[N_READS]; + cub::LoadDirectBlocked(index, in, vals, axis_size); + for (int i = 0; i < N_READS; i++) { + vals[i] = softmax_exp(static_cast(vals[i]) - maxval) * normalizer; + } + cub::StoreDirectBlocked(index, out, vals, axis_size); + } +} + +} // namespace cu + +void Softmax::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Softmax::eval_gpu"); + assert(inputs.size() == 1); + auto& s = stream(); + + // Make sure that the last dimension is contiguous. 
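+  // If the input is already contiguous along its last axis, donate its buffer
+  // or allocate matching storage for the output; otherwise materialize a
+  // contiguous copy and let the output share that copy's buffer.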
+ auto set_output = [&s, &out](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + array in = set_output(inputs[0]); + bool precise = in.dtype() != float32 && precise_; + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu new file mode 100644 index 000000000..379c55706 --- /dev/null +++ b/mlx/backend/cuda/sort.cu @@ -0,0 +1,211 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core { + +namespace { + +template +struct ModOp { + T divisor; + __device__ T operator()(T x) { + return x % divisor; + } +}; + +// We can not use any op in eval, make an utility. +array swapaxes_in_eval(const array& in, int axis1, int axis2) { + std::vector axes(in.ndim()); + std::iota(axes.begin(), axes.end(), 0); + std::swap(axes[axis1], axes[axis2]); + // TODO: Share the code with Transpose::eval. + Shape shape(axes.size()); + Strides strides(in.ndim()); + for (size_t ax = 0; ax < axes.size(); ++ax) { + shape[ax] = in.shape()[axes[ax]]; + strides[ax] = in.strides()[axes[ax]]; + } + auto flags = in.flags(); + if (flags.contiguous) { + auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides); + flags.row_contiguous = row_contiguous; + flags.col_contiguous = col_contiguous; + } + array out(shape, in.dtype(), nullptr, {}); + out.copy_shared_buffer(in, strides, flags, in.data_size()); + return out; +} + +struct OffsetTransform { + int nsort; + + int __device__ operator()(int i) { + return i * nsort; + } +}; + +void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { + array out = out_; + auto& encoder = cu::get_command_encoder(s); + if (axis < 0) { + axis += in.ndim(); + } + int nsort = in.shape(axis); + int last_dim = in.ndim() - 1; + + // If we are not sorting the innermost dimension of a contiguous array, + // transpose and make a copy. 
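+  // cub::DeviceSegmentedSort sorts equal-sized segments of one contiguous
+  // buffer, so the sort axis must be the contiguous innermost dimension;
+  // otherwise we sort a transposed copy and transpose the result back below.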
+ bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1; + if (!is_segmented_sort) { + array trans = swapaxes_in_eval(in, axis, last_dim); + in = array(trans.shape(), trans.dtype(), nullptr, {}); + copy_gpu(trans, in, CopyType::General, s); + encoder.add_temporary(in); + out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(out); + } else { + out.set_data( + allocator::malloc(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + auto& stream = encoder.stream(); + if constexpr (!std::is_same_v) { + using Type = cuda_type_t; + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), OffsetTransform{nsort}); + if (argsort) { + // Indices in the sorted dimension. + array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(indices); + + // In argsort though we don't need the result of sorted values, the + // API requires us to provide an array to store it. + array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); + encoder.add_temporary(discard); + + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + nullptr, + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(indices.data_size()), + thrust::device_pointer_cast(indices.data()), + ModOp{static_cast(nsort)}); + + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + temp.data(), + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + } else { + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + nullptr, + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + temp.data(), + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + } + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } + }); + + if (!is_segmented_sort) { + // Swap the sorted axis back. + // TODO: Do in-place transpose instead of using a temporary out array. 
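+    // The temporary `out` is sorted along its last axis; copy it through a
+    // swapped-axes view so the values land in the caller's layout in `out_`.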
+ copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s); + } +} + +} // namespace + +void ArgSort::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ArgSort::eval_gpu"); + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, true); +} + +void Sort::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Sort::eval_gpu"); + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, false); +} + +void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ArgPartition::eval_gpu"); + gpu_sort(stream(), inputs[0], out, axis_, true); +} + +void Partition::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Partition::eval_gpu"); + gpu_sort(stream(), inputs[0], out, axis_, false); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu new file mode 100644 index 000000000..eb69442c2 --- /dev/null +++ b/mlx/backend/cuda/ternary.cu @@ -0,0 +1,212 @@ +// Copyright © 2025 Apple Inc. +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/ternary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void +ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[i], c[i]); + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + auto c_vec = load_vector(c, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void ternary_g_nd( + const bool* a, + const T* b, + const T* c, + T* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides, + const __grid_constant__ cuda::std::array c_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx, c_idx] = elem_to_loc_nd( + index, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data()); + out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + } +} + +template +__global__ void ternary_g( + const bool* a, + const T* b, + const T* c, + T* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + const __grid_constant__ Strides c_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx, c_idx] = elem_to_loc_4d( + index, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data(), + ndim); + out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + } +} + +} // namespace cu + +template +void ternary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const Stream& s) { + const auto& a = inputs[0]; + const auto& b = inputs[1]; + const auto& c = inputs[2]; + if (out.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + 
encoder.set_input_array(c); + encoder.set_output_array(out); + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using DType = cuda_type_t; + + auto topt = get_ternary_op_type(a, b, c); + if (topt == TernaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::ternary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } + }); + } else { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + // TODO: Choose optimized value based on type size. + constexpr int N_READS = 4; + auto kernel = cu::ternary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size()); + }); + } + }); +} + +template +void ternary_op_gpu( + const std::vector& inputs, + array& out, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& c = inputs[2]; + auto topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt); + ternary_op_gpu_inplace(inputs, out, s); +} + +void Select::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("select::eval_gpu"); + auto& s = out.primitive().stream(); + ternary_op_gpu(inputs, out, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu new file mode 100644 index 000000000..ddb32d05e --- /dev/null +++ b/mlx/backend/cuda/unary.cu @@ -0,0 +1,262 @@ +// Copyright © 2025 Apple Inc. 
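+// Element-wise unary primitives for the CUDA backend: a vectorized kernel
+// (unary_v) for contiguous inputs and a general strided kernel (unary_g).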
+ +#include "mlx/backend/common/unary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/unary_ops.cuh" +#include "mlx/backend/cuda/iterators/general_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void unary_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(in_vec.val[i]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void unary_g( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim); + out[index] = Op{}(in[idx]); + } +} + +template +constexpr bool supports_unary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && !mlx::core::is_complex_v; + } + if (std::is_same_v) { + return std::is_same_v && mlx::core::is_complex_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v || std::is_same_v) { + return mlx::core::is_complex_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + return false; +} + +} // namespace cu + +template +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + auto& in = inputs[0]; + if (in.size() == 0) { + return; + } + bool contig = in.flags().contiguous; + bool large; + if (!contig) { + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_unary_op()) { + dispatch_bool(large, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + if (contig) { + using IdxT = std::conditional_t; + // TODO: Choose optimized value based on type size. 
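+            // Each thread processes N_READS consecutive elements so that the
+            // loads and stores in unary_v can be vectorized.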
+ constexpr int N_READS = 4; + auto kernel = cu::unary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out.data_size(), + out.shape(), + out.strides(), + large, + N_READS); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size()); + } else { + using IdxT = std::conditional_t; + auto [shape, strides] = collapse_contiguous_dims(in); + auto kernel = cu::unary_g; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(strides), + shape.size()); + } + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); +} + +template +void unary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + set_unary_output_data(inputs[0], out); + unary_op_gpu_inplace(inputs, out, op, s); +} + +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, name(), s); \ + } + +UNARY_GPU(Abs) +UNARY_GPU(ArcCos) +UNARY_GPU(ArcCosh) +UNARY_GPU(ArcSin) +UNARY_GPU(ArcSinh) +UNARY_GPU(ArcTan) +UNARY_GPU(ArcTanh) +UNARY_GPU(BitwiseInvert) +UNARY_GPU(Ceil) +UNARY_GPU(Conjugate) +UNARY_GPU(Cos) +UNARY_GPU(Cosh) +UNARY_GPU(Erf) +UNARY_GPU(ErfInv) +UNARY_GPU(Exp) +UNARY_GPU(Expm1) +UNARY_GPU(Floor) +UNARY_GPU(Imag) +UNARY_GPU(Log1p) +UNARY_GPU(LogicalNot) +UNARY_GPU(Negative) +UNARY_GPU(Real) +UNARY_GPU(Sigmoid) +UNARY_GPU(Sign) +UNARY_GPU(Sin) +UNARY_GPU(Sinh) +UNARY_GPU(Square) +UNARY_GPU(Tan) +UNARY_GPU(Tanh) + +void Log::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Log::eval_gpu"); + auto& s = out.primitive().stream(); + switch (base_) { + case Base::e: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::two: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::ten: + unary_op_gpu(inputs, out, name(), s); + break; + } +} + +void Round::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Round::eval_gpu"); + assert(inputs.size() == 1); + const auto& in = inputs[0]; + auto& s = out.primitive().stream(); + if (issubdtype(in.dtype(), inexact)) { + unary_op_gpu(inputs, out, name(), s); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} + +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Sort::eval_gpu"); + auto& s = out.primitive().stream(); + if (recip_) { + unary_op_gpu(inputs, out, "Rsqrt", s); + } else { + unary_op_gpu(inputs, out, "Sqrt", s); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp new file mode 100644 index 000000000..1c12fa4df --- /dev/null +++ b/mlx/backend/cuda/utils.cpp @@ -0,0 +1,70 @@ +// Copyright © 2025 Apple Inc. 
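+// Shared CUDA helpers: the RAII CudaStream wrapper, error checking for
+// runtime and driver calls, and a Dtype to CUDA type-name mapping.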
+ +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +CudaStream::CudaStream(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); +} + +CudaStream::~CudaStream() { + CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); +} + +void check_cuda_error(const char* name, cudaError_t err) { + if (err != cudaSuccess) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, cudaGetErrorString(err))); + } +} + +void check_cuda_error(const char* name, CUresult err) { + if (err != CUDA_SUCCESS) { + const char* err_str = "Unknown error"; + cuGetErrorString(err, &err_str); + throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); + } +} + +const char* dtype_to_cuda_type(const Dtype& dtype) { + switch (dtype) { + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + case float16: + return "__half"; + case bfloat16: + return "__nv_bfloat16"; + case float32: + return "float"; + case float64: + return "double"; + case complex64: + return "complex64_t"; + default: + return "unknown"; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h new file mode 100644 index 000000000..bfb02c5b6 --- /dev/null +++ b/mlx/backend/cuda/utils.h @@ -0,0 +1,45 @@ +// Copyright © 2025 Apple Inc. + +// This file include utilies that are used by C++ code (i.e. .cpp files). + +#pragma once + +#include +#include + +namespace mlx::core { + +namespace cu { +class Device; +} + +struct Dtype; + +// Cuda stream managed with RAII. +class CudaStream { + public: + explicit CudaStream(cu::Device& device); + ~CudaStream(); + + CudaStream(const CudaStream&) = delete; + CudaStream& operator=(const CudaStream&) = delete; + + operator cudaStream_t() const { + return stream_; + } + + private: + cudaStream_t stream_; +}; + +// Throw exception if the cuda API does not succeed. +void check_cuda_error(const char* name, cudaError_t err); +void check_cuda_error(const char* name, CUresult err); + +// The macro version that prints the command that failed. +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) + +// Convert Dtype to CUDA C++ types. +const char* dtype_to_cuda_type(const Dtype& dtype); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp new file mode 100644 index 000000000..3b35c830b --- /dev/null +++ b/mlx/backend/cuda/worker.cpp @@ -0,0 +1,92 @@ +// Copyright © 2025 Apple Inc. 
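+// Worker runs host-side callbacks on a dedicated thread once the GPU work
+// they were batched behind has completed (see commit(cudaStream_t) below).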
+ +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core::cu { + +Worker::Worker() + : signal_stream_(device(mlx::core::Device::gpu)), + worker_(&Worker::thread_fn, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(worker_mutex_); + stop_ = true; + } + worker_event_.signal(batch_ + 1); + worker_.join(); +} + +void Worker::add_task(std::function task) { + pending_tasks_.push_back(std::move(task)); +} + +void Worker::consume_in_this_thread() { + for (auto& task : pending_tasks_) { + task(); + } + pending_tasks_.clear(); +} + +void Worker::end_batch() { + batch_++; + { + std::lock_guard lock(worker_mutex_); + worker_tasks_[batch_] = std::move(pending_tasks_); + } + uncommited_batches_++; +} + +void Worker::commit() { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + worker_event_.signal(batch_); +} + +void Worker::commit(cudaStream_t stream) { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + // Signal the |worker_event_| in |signal_stream_| after the kernels in + // |stream_| finish running. + signal_event_.record(stream); + signal_event_.wait(signal_stream_); + worker_event_.signal(signal_stream_, batch_); +} + +void Worker::thread_fn() { + // The worker thread is safe to free buffers. + allocator().register_this_thread(); + + while (!stop_) { + uint64_t batch = worker_event_.value(); + Tasks tasks; + { + std::lock_guard lock(worker_mutex_); + // Move tasks in signaled batches. + auto end = worker_tasks_.upper_bound(batch); + for (auto it = worker_tasks_.begin(); it != end; ++it) { + if (tasks.empty()) { + tasks = std::move(it->second); + } else { + std::move( + it->second.begin(), it->second.end(), std::back_inserter(tasks)); + } + } + worker_tasks_.erase(worker_tasks_.begin(), end); + } + // Make sure tasks are cleared before the next wait + for (int i = 0; i < tasks.size(); ++i) { + auto task = std::move(tasks[i]); + task(); + } + worker_event_.wait(batch + 1); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h new file mode 100644 index 000000000..d28e22e95 --- /dev/null +++ b/mlx/backend/cuda/worker.h @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +// Run tasks in worker thread, synchronized with cuda stream. +class Worker { + public: + Worker(); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or commited. + void add_task(std::function task); + + // Run pending tasks immediately in current thread. + void consume_in_this_thread(); + + // Put pending tasks in a batch. + void end_batch(); + + // Inform worker thread to run current batches now. + void commit(); + + // Inform worker thread to run current batches after kernels in |stream| + // finish running. + void commit(cudaStream_t stream); + + // Return how many batches have been added but not committed yet. + size_t uncommited_batches() const { + return uncommited_batches_; + } + + private: + void thread_fn(); + + uint64_t batch_{0}; + size_t uncommited_batches_{0}; + + // Cuda stream and event for signaling kernel completion. + CudaStream signal_stream_; + CudaEvent signal_event_; + + // Worker thread. 
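+  // worker_event_ is signaled with the current batch counter to wake the
+  // thread; worker_mutex_ protects worker_tasks_, which is shared with it.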
+ SharedEvent worker_event_; + std::thread worker_; + std::mutex worker_mutex_; + bool stop_{false}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called. + using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/gpu/CMakeLists.txt b/mlx/backend/gpu/CMakeLists.txt new file mode 100644 index 000000000..0396ae03a --- /dev/null +++ b/mlx/backend/gpu/CMakeLists.txt @@ -0,0 +1,5 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) diff --git a/mlx/backend/gpu/available.h b/mlx/backend/gpu/available.h new file mode 100644 index 000000000..476c7acf2 --- /dev/null +++ b/mlx/backend/gpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::gpu { + +bool is_available(); + +} // namespace mlx::core::gpu diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp new file mode 100644 index 000000000..6127ac921 --- /dev/null +++ b/mlx/backend/gpu/copy.cpp @@ -0,0 +1,49 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + +void copy_gpu(const array& in, array& out, CopyType ctype) { + copy_gpu(in, out, ctype, out.primitive().stream()); +} + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); +} + +void copy_gpu_inplace( + const array& in, + array& out, + const Strides& i_strides, + int64_t i_offset, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/copy.h b/mlx/backend/gpu/copy.h similarity index 98% rename from mlx/backend/metal/copy.h rename to mlx/backend/gpu/copy.h index 37c60df42..020f579e4 100644 --- a/mlx/backend/metal/copy.h +++ b/mlx/backend/gpu/copy.h @@ -5,6 +5,8 @@ #include "mlx/backend/common/copy.h" #include "mlx/stream.h" +#include + namespace mlx::core { // Generic copy inplace diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/gpu/eval.h similarity index 63% rename from mlx/backend/metal/metal_impl.h rename to mlx/backend/gpu/eval.h index 9ca8d2f80..f646c2ec9 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/gpu/eval.h @@ -8,14 +8,11 @@ #include "mlx/array.h" #include "mlx/stream.h" -namespace mlx::core::metal { +namespace mlx::core::gpu { void new_stream(Stream stream); - -std::unique_ptr> new_scoped_memory_pool(); - void eval(array& arr); void finalize(Stream s); void synchronize(Stream s); -} // namespace mlx::core::metal +} // namespace mlx::core::gpu diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp new file mode 100644 index 000000000..1adb85918 --- /dev/null +++ 
b/mlx/backend/gpu/primitives.cpp @@ -0,0 +1,261 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/primitives.h" +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +#if defined(MLX_USE_CUDA) +#include +#endif + +#include + +#if defined(MLX_USE_CUDA) +#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message) +#else +#define MLX_PROFILER_RANGE(message) +#endif + +namespace mlx::core { + +namespace { + +void reshape(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +} // namespace + +void AsStrided::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsStrided::eval_gpu"); + eval(inputs, out); +} + +void AsType::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsType::eval_gpu"); + CopyType ctype = + inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; + copy_gpu(inputs[0], out, ctype); +} + +void Broadcast::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Broadcast::eval_gpu"); + eval(inputs, out); +} + +void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu"); + eval(inputs, out); +} + +void Concatenate::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Concatenate::eval_gpu"); + concatenate_gpu(inputs, out, axis_, stream()); +} + +void Contiguous::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Contiguous::eval_gpu"); + assert(inputs.size() == 1); + auto& in = inputs[0]; + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { + out.copy_shared_buffer(in); + } else { + copy_gpu(in, out, CopyType::General); + } +} + +void Copy::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Copy::eval_gpu"); + eval(inputs, out); +} + +void CustomTransforms::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("CustomTransforms::eval_gpu"); + eval(inputs, outputs); +} + +void Depends::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Depends::eval_gpu"); + eval(inputs, outputs); +} + +void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); + eval(inputs, out); +} + +void Full::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Full::eval_gpu"); + auto in = inputs[0]; + CopyType ctype; + if (in.data_size() == 1) { + ctype = CopyType::Scalar; + } else if (in.flags().contiguous) { + ctype = CopyType::Vector; + } else { + ctype = CopyType::General; + } + copy_gpu(in, out, ctype); +} + +void Flatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Flatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("NumberOfElements::eval_gpu"); + eval(inputs, out); +} + +void Pad::eval_gpu(const std::vector& inputs, array& out) { + // Inputs must be base input array and scalar val 
array + assert(inputs.size() == 2); + auto& in = inputs[0]; + auto& val = inputs[1]; + + // Padding value must be a scalar + assert(val.size() == 1); + + // Padding value, input and output must be of the same type + assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); + + pad_gpu(in, val, out, axes_, low_pad_size_, stream()); +} + +void Reshape::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Reshape::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void Split::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Split::eval_gpu"); + eval(inputs, outputs); +} + +void Slice::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Slice::eval_gpu"); + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + slice_gpu(in, out, start_indices_, strides_, stream()); +} + +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); +} + +void Squeeze::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Squeeze::eval_gpu"); + eval(inputs, out); +} + +void StopGradient::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("StopGradient::eval_gpu"); + eval(inputs, out); +} + +void Transpose::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Transpose::eval_gpu"); + eval(inputs, out); +} + +void Unflatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Unflatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void View::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("View::eval_gpu"); + auto& in = inputs[0]; + auto ibytes = size_of(in.dtype()); + auto obytes = size_of(out.dtype()); + // Conditions for buffer copying (disjunction): + // - type size is the same + // - type size is smaller and the last axis is contiguous + // - the entire array is row contiguous + if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || + in.flags().row_contiguous) { + auto strides = in.strides(); + for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { + strides[i] *= ibytes; + strides[i] /= obytes; + } + out.copy_shared_buffer( + in, strides, in.flags(), in.data_size() * ibytes / obytes); + } else { + auto tmp = array(in.shape(), in.dtype(), nullptr, {}); + tmp.set_data(allocator::malloc(tmp.nbytes())); + copy_gpu_inplace(in, tmp, CopyType::General, stream()); + + auto flags = out.flags(); + flags.contiguous = true; + flags.row_contiguous = true; + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = 
out.size() <= 1 || out.size() == *max_dim; + out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/gpu/slicing.cpp b/mlx/backend/gpu/slicing.cpp new file mode 100644 index 000000000..fde2a01cd --- /dev/null +++ b/mlx/backend/gpu/slicing.cpp @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void slice_gpu( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides, + const Stream& s) { + slice(in, out, start_indices, strides); +} + +void pad_gpu( + const array& in, + const array& val, + array& out, + const std::vector& axes, + const Shape& low_pad_size, + const Stream& s) { + // Fill output with val + fill_gpu(val, out, s); + + // Find offset for start of input values + size_t data_offset = 0; + for (int i = 0; i < axes.size(); i++) { + auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i]; + data_offset += out.strides()[ax] * low_pad_size[i]; + } + + // Extract slice from output where input will be pasted + array out_slice(in.shape(), out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out, out.strides(), out.flags(), out_slice.size(), data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/gpu/slicing.h similarity index 100% rename from mlx/backend/metal/slicing.h rename to mlx/backend/gpu/slicing.h diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 332c560f8..ccdd83202 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -63,6 +63,7 @@ if(MLX_METAL_JIT) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(steel/gemm/kernels/steel_gemm_gather) make_jit_source(steel/gemm/kernels/steel_gemm_splitk) + make_jit_source(steel/gemm/kernels/steel_gemm_segmented) make_jit_source( steel/conv/conv kernels/steel/utils.h @@ -93,6 +94,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0a69dd261..dd6189732 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -1,7 +1,6 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/resident.h" #include "mlx/memory.h" @@ -31,141 +30,18 @@ void* Buffer::raw_ptr() { namespace metal { -namespace { - -BufferCache::BufferCache(ResidencySet& residency_set) - : head_(nullptr), - tail_(nullptr), - pool_size_(0), - residency_set_(residency_set) {} - -BufferCache::~BufferCache() { - auto pool = metal::new_scoped_memory_pool(); - clear(); -} - -int BufferCache::clear() { - int n_release = 0; - for (auto& [size, holder] : buffer_pool_) { - if (holder->buf) { - if (!holder->buf->heap()) { - residency_set_.erase(holder->buf); - } - holder->buf->release(); - n_release++; - } - delete holder; - } - buffer_pool_.clear(); - pool_size_ = 0; - head_ = nullptr; - tail_ = nullptr; - return n_release; -} - -MTL::Buffer* BufferCache::reuse_from_cache(size_t size) { - // Find the closest buffer in pool - MTL::Buffer* pbuf = nullptr; - - auto it = buffer_pool_.lower_bound(size); - - // Make sure we use most of the available memory - while (!pbuf && it != buffer_pool_.end() && - it->first < std::min(2 * size, size + 2 * vm_page_size)) { - // Collect from the cache - pbuf = it->second->buf; - - // Remove from cache - remove_from_list(it->second); - delete it->second; - it = buffer_pool_.erase(it); - } - - if (pbuf) { - pool_size_ -= pbuf->length(); - } - - return pbuf; -} - -void BufferCache::recycle_to_cache(MTL::Buffer* buf) { - // Add to cache - if (buf) { - BufferHolder* bh = new BufferHolder(buf); - add_at_head(bh); - pool_size_ += buf->length(); - buffer_pool_.insert({buf->length(), bh}); - } -} - -int BufferCache::release_cached_buffers(size_t min_bytes_to_free) { - if (min_bytes_to_free >= 0.9 * pool_size_) { - return clear(); - } else { - int n_release = 0; - size_t total_bytes_freed = 0; - - while (tail_ && (total_bytes_freed < min_bytes_to_free)) { - if (tail_->buf) { - total_bytes_freed += tail_->buf->length(); - if (!tail_->buf->heap()) { - residency_set_.erase(tail_->buf); - } - tail_->buf->release(); - tail_->buf = nullptr; - n_release++; - } - remove_from_list(tail_); - } - pool_size_ -= total_bytes_freed; - return n_release; - } -} - -void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) { - if (!to_add) - return; - - if (!head_) { - head_ = to_add; - tail_ = to_add; - } else { - head_->prev = to_add; - to_add->next = head_; - head_ = to_add; - } -} - -void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { - if (!to_remove) { - return; - } - - // If in the middle - if (to_remove->prev && to_remove->next) { - to_remove->prev->next = to_remove->next; - to_remove->next->prev = to_remove->prev; - } else if (to_remove->prev && to_remove == tail_) { // If tail - tail_ = to_remove->prev; - tail_->next = nullptr; - } else if (to_remove == head_ && to_remove->next) { // If head - head_ = to_remove->next; - head_->prev = nullptr; - } else if (to_remove == head_ && to_remove == tail_) { // If only element - head_ = nullptr; - tail_ = nullptr; - } - - to_remove->prev = nullptr; - to_remove->next = nullptr; -} - -} // namespace - MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), residency_set_(device_), - buffer_cache_(residency_set_) { + buffer_cache_( + vm_page_size, + [](MTL::Buffer* buf) { return buf->length(); }, + [this](MTL::Buffer* buf) { + if (!buf->heap()) { + residency_set_.erase(buf); + } + buf->release(); + }) { auto pool = metal::new_scoped_memory_pool(); auto 
memsize = std::get(device_info().at("memory_size")); auto max_rec_size = @@ -194,6 +70,7 @@ MetalAllocator::~MetalAllocator() { if (heap_) { heap_->release(); } + buffer_cache_.clear(); } size_t MetalAllocator::set_cache_limit(size_t limit) { diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 227b09e91..691317916 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -7,6 +7,7 @@ #include #include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/resident.h" @@ -14,43 +15,6 @@ namespace mlx::core::metal { using allocator::Buffer; -namespace { - -class BufferCache { - public: - BufferCache(ResidencySet& residency_set); - ~BufferCache(); - - MTL::Buffer* reuse_from_cache(size_t size); - void recycle_to_cache(MTL::Buffer* buf); - int release_cached_buffers(size_t min_bytes_to_free); - size_t cache_size() { - return pool_size_; - } - int clear(); - - private: - struct BufferHolder { - public: - BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {} - - BufferHolder* prev; - BufferHolder* next; - MTL::Buffer* buf; - }; - - void add_at_head(BufferHolder* to_add); - void remove_from_list(BufferHolder* to_remove); - - std::multimap buffer_pool_; - BufferHolder* head_; - BufferHolder* tail_; - size_t pool_size_; - ResidencySet& residency_set_; -}; - -} // namespace - class MetalAllocator : public allocator::Allocator { /** Allocator for Metal GPUs. */ public: @@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator { friend MetalAllocator& allocator(); // Caching allocator - BufferCache buffer_cache_; + BufferCache buffer_cache_; ResidencySet residency_set_; diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index f80f8c3e4..8c0e8c333 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -7,20 +7,20 @@ #define BINARY_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ - binary_op_gpu(inputs, out, get_primitive_string(this)); \ + binary_op_gpu(inputs, out, name()); \ } #define BINARY_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ - binary_op_gpu(inputs, outputs, get_primitive_string(this)); \ + binary_op_gpu(inputs, outputs, name()); \ } namespace mlx::core { std::string get_kernel_name( BinaryOpType bopt, - const std::string& op, + const char* op, const array& a, bool large, int ndim, @@ -31,13 +31,13 @@ std::string get_kernel_name( kname = "ss"; break; case BinaryOpType::ScalarVector: - kname = (large ? "sv2" : "sv"); + kname = "sv"; break; case BinaryOpType::VectorScalar: - kname = (large ? "vs2" : "vs"); + kname = "vs"; break; case BinaryOpType::VectorVector: - kname = (large ? "vv2" : "vv"); + kname = "vv"; break; case BinaryOpType::General: kname = "g"; @@ -51,6 +51,13 @@ std::string get_kernel_name( } break; } + if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) { + if (large) { + kname += "2"; + } else if (work_per_thread > 1) { + kname += "n"; + } + } concatenate(kname, "_", op, type_to_name(a)); return kname; } @@ -58,7 +65,7 @@ std::string get_kernel_name( void binary_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -90,7 +97,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 
4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(a.dtype(), out.data_size()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); @@ -137,13 +144,20 @@ void binary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -151,7 +165,7 @@ void binary_op_gpu_inplace( void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s) { assert(inputs.size() == 2); auto& a = inputs[0]; @@ -165,7 +179,7 @@ void binary_op_gpu( void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op) { + const char* op) { auto& s = outputs[0].primitive().stream(); binary_op_gpu(inputs, outputs, op, s); } @@ -173,7 +187,7 @@ void binary_op_gpu( void binary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { std::vector outputs = {out}; binary_op_gpu_inplace(inputs, outputs, op, s); @@ -182,7 +196,7 @@ void binary_op_gpu_inplace( void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { assert(inputs.size() == 2); auto& a = inputs[0]; @@ -195,7 +209,7 @@ void binary_op_gpu( void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op) { + const char* op) { auto& s = out.primitive().stream(); binary_op_gpu(inputs, out, op, s); } @@ -223,19 +237,19 @@ BINARY_GPU(Subtract) void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { switch (op_) { case BitwiseBinary::And: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::Or: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::Xor: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::LeftShift: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::RightShift: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; } } diff --git a/mlx/backend/metal/binary.h b/mlx/backend/metal/binary.h index 8552c1e07..0341a2f83 100644 --- a/mlx/backend/metal/binary.h +++ b/mlx/backend/metal/binary.h @@ -9,25 +9,25 @@ namespace mlx::core { void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - 
const std::string& op, + const char* op, const Stream& s); void binary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 154273233..eb51ab750 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -11,8 +11,6 @@ #include "mlx/primitives.h" #include "mlx/utils.h" -using namespace fmt::literals; - namespace mlx::core { inline void build_kernel( @@ -21,21 +19,12 @@ inline void build_kernel( const std::vector& inputs, const std::vector& outputs, const std::vector& tape, - const std::unordered_set& constant_ids, + const std::function& is_constant, bool contiguous, int ndim, bool dynamic_dims, bool use_big_index = false, int work_per_thread = 1) { - // All outputs should have the exact same shape and will be row contiguous - auto output_shape = outputs[0].shape(); - auto output_strides = outputs[0].strides(); - - // Constants are scalars that are captured by value and cannot change - auto is_constant = [&constant_ids](const array& x) { - return constant_ids.find(x.id()) != constant_ids.end(); - }; - NodeNamer namer; bool add_indices = false; int cnt = 0; @@ -45,14 +34,15 @@ inline void build_kernel( "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name); // Add the input arguments - for (auto& x : inputs) { - auto& xname = namer.get_name(x); - + for (size_t i = 0; i < inputs.size(); ++i) { // Skip constants from the input list - if (is_constant(x)) { + if (is_constant(i)) { continue; } + const auto& x = inputs[i]; + auto& xname = namer.get_name(x); + // Scalars and contiguous need no strides if (!is_scalar(x) && !contiguous) { add_indices = true; @@ -64,6 +54,7 @@ inline void build_kernel( cnt++); } + std::string idx_type = use_big_index ? "int64_t" : "uint"; if (add_indices) { os += fmt::format( " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); @@ -79,10 +70,11 @@ inline void build_kernel( } // Add output strides and shape to extract the indices. if (!contiguous) { - os += fmt::format( - " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); os += fmt::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); + } else { + os += fmt::format( + " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++); } if (dynamic_dims) { os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); @@ -92,13 +84,14 @@ inline void build_kernel( os += " uint3 pos [[thread_position_in_grid]],\n"; os += " uint3 grid [[threads_per_grid]]) {\n"; - std::string idx_type = use_big_index ? "int64_t" : "uint"; + os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; + os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n"; + } else if (contiguous) { + os += " uint index = N_ * pos.x;\n"; } else if (work_per_thread > 1) { - os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += fmt::format( " int xshape = output_shape[{0}];\n", dynamic_dims ? 
"ndim - 1" : std::to_string(ndim - 1)); @@ -110,6 +103,9 @@ inline void build_kernel( " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } + if (work_per_thread > 1 && contiguous) { + os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; + } // Read constant / contiguous inputs in tmps std::vector nc_inputs; @@ -117,7 +113,7 @@ inline void build_kernel( auto& x = inputs[i]; auto& xname = namer.get_name(x); - if (is_constant(x)) { + if (is_constant(i)) { auto type_str = get_type_string(x.dtype()); std::ostringstream ss; print_constant(ss, x); @@ -193,7 +189,7 @@ inline void build_kernel( } // Open per-thread loop - if (work_per_thread > 1) { + if (work_per_thread > 1 && !contiguous) { os += " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } @@ -216,9 +212,7 @@ inline void build_kernel( get_type_string(x.dtype()), namer.get_name(x.inputs()[0])); } else { - std::ostringstream ss; - x.primitive().print(ss); - os += ss.str(); + os += x.primitive().name(); os += "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); @@ -263,15 +257,11 @@ inline void build_kernel( void Compiled::eval_gpu( const std::vector& inputs, std::vector& outputs) { - // Make the name for the kernel library - if (kernel_lib_.empty()) { - kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); - } - // Get the kernel if someone else built it already auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_lib_, [&]() { + int work_per_thread = get_work_per_thread(outputs_[0].dtype()); std::string kernel = metal::utils(); concatenate( kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); @@ -281,21 +271,38 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ true, /* ndim = */ 0, - /* dynamic_dims = */ false); + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ 1); + if (work_per_thread > 1) { + build_kernel( + kernel, + kernel_lib_ + "_contiguous_n", + inputs_, + outputs_, + tape_, + is_constant_, + /* contiguous = */ true, + /* ndim = */ 0, + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); + } build_kernel( kernel, kernel_lib_ + "_contiguous_large", inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, - /* use_big_index = */ true); + /* use_big_index = */ true, + /* work_per_thread = */ work_per_thread); for (int i = 1; i < 8; i++) { build_kernel( kernel, @@ -303,7 +310,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ i, /* dynamic_dims = */ false, @@ -316,7 +323,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ i, /* dynamic_dims = */ false, @@ -330,7 +337,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ 0, /* dynamic_dims = */ true, @@ -342,7 +349,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ 0, /* dynamic_dims = */ true, @@ -351,81 +358,32 @@ void Compiled::eval_gpu( return kernel; }); - // Figure out which kernel we are using - auto& output_shape = outputs[0].shape(); - auto contiguous = 
compiled_check_contiguity(inputs, output_shape); - // Collapse contiguous dims to route to a faster kernel if possible. Also // handle all broadcasting. - std::vector initial_strides; - initial_strides.push_back(outputs[0].strides()); - Shape shape; - std::vector strides; - if (!contiguous) { - for (int i = 0; i < inputs.size(); i++) { - // Skip constants. - if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { - continue; - } - auto& x = inputs[i]; + auto [contiguous, shape, strides] = + compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); - // Skip scalar inputs. - if (is_scalar(x)) { - continue; - } - - // Broadcast the inputs to the output shape. - Strides xstrides; - int j = 0; - for (; j < output_shape.size() - x.ndim(); j++) { - if (output_shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } - for (int i = 0; i < x.ndim(); i++, j++) { - if (x.shape(i) == 1) { - if (output_shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } else { - xstrides.push_back(x.strides()[i]); - } - } - initial_strides.push_back(std::move(xstrides)); - } - std::tie(shape, strides) = - collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX); - } - - bool large; - if (contiguous) { - size_t max_size = 0; - for (auto& in : inputs) { - max_size = std::max(max_size, in.data_size()); - } - large = (max_size > UINT32_MAX); - } else { - size_t max_size = 0; - for (auto& o : outputs) { - max_size = std::max(max_size, o.size()); - } - large = (max_size > UINT32_MAX); - } + // Whether to use large index. + bool large = compiled_use_large_index(inputs, outputs, contiguous); // Get the kernel from the lib int ndim = shape.size(); bool dynamic = ndim >= 8; auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + int work_per_thread = 1; if (!contiguous) { if (dynamic) { kernel_name += "dynamic"; } else { kernel_name += std::to_string(shape.size()); } + work_per_thread = ndim > 3 ? (large ? 
4 : 2) : 1; + } else { + work_per_thread = + get_work_per_thread(outputs[0].dtype(), outputs[0].data_size()); + if (work_per_thread > 1 && !large) { + kernel_name += "_n"; + } } if (large) { kernel_name += "_large"; @@ -439,7 +397,7 @@ void Compiled::eval_gpu( int stride_idx = 1; // idx 0 is the output strides Strides in_strides; for (int i = 0; i < inputs.size(); i++) { - if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + if (is_constant_(i)) { continue; } auto& x = inputs[i]; @@ -456,8 +414,7 @@ void Compiled::eval_gpu( compute_encoder.set_vector_bytes(in_strides, cnt++); } - compiled_allocate_outputs( - inputs, outputs, inputs_, constant_ids_, contiguous); + compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); // Put the outputs in for (auto& x : outputs) { @@ -466,8 +423,14 @@ void Compiled::eval_gpu( // Put the output shape and strides in if (!contiguous) { - compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder.set_vector_bytes(shape, cnt++); + } else { + auto size = outputs[0].data_size(); + if (large) { + compute_encoder.set_bytes(size, cnt++); + } else { + compute_encoder.set_bytes(size, cnt++); + } } // Put the number of dims in if it is dynamic @@ -477,19 +440,18 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - size_t nthreads = outputs[0].data_size(); + size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - MTL::Size grid_dims = large - ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) + ? get_2d_grid_dims( + outputs[0].shape(), outputs[0].strides(), work_per_thread) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t rest = outputs[0].size() / (dim0 * dim1); - int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; dim0 = (dim0 + work_per_thread - 1) / work_per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); int pow2; diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 9075ea4c5..9eb6a6385 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -1,11 +1,10 @@ // Copyright © 2023-2024 Apple Inc. 
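Note on the contiguous dispatch changes in the binary.cpp, copy.cpp, and compiled.cpp hunks above: the launch math reduces to a ceiling division of the flat element count by the per-thread work factor, and the kernel-name suffix encodes the indexing mode ("2" for 64-bit indexing, "n" for multiple elements per thread). The sketch below is illustrative only; pick_suffix and launch_threads are hypothetical helpers, and only the suffix convention and the ceiling division are taken from the diff.

// Illustrative sketch (not MLX API): how the contiguous kernel variant and
// thread count follow from the element count and work-per-thread factor.
#include <cstddef>
#include <string>

std::string pick_suffix(bool large, int work_per_thread) {
  if (large) {
    return "2";                            // 64-bit indexing variant
  }
  return work_per_thread > 1 ? "n" : "";   // multi-element-per-thread variant
}

size_t launch_threads(size_t data_size, int work_per_thread) {
  // ceildiv: each thread covers work_per_thread elements
  return (data_size + work_per_thread - 1) / work_per_thread;
}

For example, 1,000,001 elements with a hypothetical work_per_thread of 4 launch 250,001 threads and, when the array is not "large", select the "n" kernel variant.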
- #include #include #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" @@ -156,103 +155,26 @@ void explicit_gemm_conv_group_ND_gpu( // Perform gemm std::vector copies = {in_unfolded, wt_transpose}; return steel_matmul_regular( - s, - d, - /* a = */ in_unfolded, - /* b = */ wt_transpose, - /* c = */ out, - /* M = */ implicit_M, - /* N = */ implicit_N, - /* K = */ implicit_K, - /* batch_size_out = */ groups, - /* a_cols = */ implicit_K * groups, - /* b_cols = */ implicit_K, - /* out_cols = */ implicit_N * groups, - /* a_transposed = */ false, - /* b_transposed = */ true, - /* batch_shape = */ {1}, - /* batch_strides = */ {0}, - /* A_batch_strides = */ size_t(implicit_K), - /* B_batch_strides = */ size_t(implicit_N) * implicit_K, - /* matrix_stride_out = */ size_t(implicit_N), - /*copies = */ copies); -} - -void conv_1D_gpu( - const Stream& s, - metal::Device& d, - const array& in, - const array& wt, - array out, - const std::vector& padding, - const std::vector& wt_strides, - const std::vector& wt_dilation, - const std::vector& in_dilation, - int groups, - bool flip) { - // Make conv params - MLXConvParams<1> conv_params{ - /* const int N = */ static_cast(in.shape(0)), - /* const int C = */ static_cast(in.shape(2)), - /* const int O = */ static_cast(wt.shape(0)), - /* const int iS[NDIM] = */ {static_cast(in.shape(1))}, - /* const int wS[NDIM] = */ {static_cast(wt.shape(1))}, - /* const int oS[NDIM] = */ {static_cast(out.shape(1))}, - /* const int str[NDIM] = */ {wt_strides[0]}, - /* const int pad[NDIM] = */ {padding[0]}, - /* const int kdil[NDIM] = */ {wt_dilation[0]}, - /* const int idil[NDIM] = */ {in_dilation[0]}, - /* const size_t in_strides[NDIM + 2] = */ - {in.strides()[0], in.strides()[1], in.strides()[2]}, - /* const size_t wt_strides[NDIM + 2] = */ - {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, - /* const size_t out_strides[NDIM + 2] = */ - {out.strides()[0], out.strides()[1], out.strides()[2]}, - /* const int groups = */ groups, - /* const bool flip = */ flip}; - - // Direct to explicit gemm conv - if (groups > 1) { - return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); - } else { - return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); - } -} - -void slow_conv_2D_gpu( - const Stream& s, - metal::Device& d, - const array& in, - const array& wt, - array out, - const MLXConvParams<2>& conv_params) { - int bm = 16, bn = 8; - int tm = 4, tn = 4; - - std::ostringstream kname; - kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn - << "_tm" << tm << "_tn" << tn; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - size_t n_pixels = conv_params.oS[0] * conv_params.oS[1]; - - size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm); - size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn); - size_t grid_dim_z = conv_params.N; - - MTL::Size group_dims = MTL::Size(bm, bn, 1); - MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z); - - compute_encoder.set_input_array(in, 0); - compute_encoder.set_input_array(wt, 1); - compute_encoder.set_output_array(out, 2); - - compute_encoder.set_bytes(conv_params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + /* const Stream& s = */ s, + /* 
Device& d = */ d, + /* const array& a = */ in_unfolded, + /* const array& b = */ wt_transpose, + /* array& c = */ out, + /* int M = */ implicit_M, + /* int N = */ implicit_N, + /* int K = */ implicit_K, + /* int batch_size_out = */ groups, + /* int lda = */ implicit_K * groups, + /* int ldb = */ implicit_K, + /* int ldd = */ implicit_N * groups, + /* bool transpose_a = */ false, + /* bool transpose_b = */ true, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ {1}, + /* Strides batch_strides = */ {0}, + /* int64_t A_batch_strides = */ int64_t(implicit_K), + /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K, + /* int64_t matrix_stride_out = */ int64_t(implicit_N)); } void implicit_gemm_conv_2D_gpu( @@ -469,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu( // Get channel iteration info int channel_k_iters = ((conv_params.C + bk - 1) / bk); int gemm_k_iters = channel_k_iters; + bool align_C = conv_params.C % bk == 0; // Fix host side helper params int sign = (conv_params.flip ? -1 : 1); @@ -497,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu( /* const int swizzle_log = */ swizzle_log}; // Determine kernel - std::ostringstream kname; - kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm - << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; + std::string kname; + kname.reserve(64); + concatenate( + kname, + "implicit_gemm_conv_2d_general_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + std::string hash_name; + hash_name.reserve(64); + concatenate(hash_name, kname, "_alC_", align_C); + metal::MTLFCList func_consts = { + {&align_C, MTL::DataType::DataTypeBool, 200}, + }; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = - get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn); + auto kernel = get_steel_conv_general_kernel( + d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn); compute_encoder.set_compute_pipeline_state(kernel); // Deduce grid launch dimensions @@ -755,7 +697,7 @@ void depthwise_conv_2D_gpu( std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); @@ -771,6 +713,143 @@ void depthwise_conv_2D_gpu( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +void dispatch_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params, + std::vector& copies) { + bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; + bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; + bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; + + if (is_idil_one && conv_params.groups > 1) { + const int C_per_group = conv_params.C / conv_params.groups; + const int O_per_group = conv_params.O / conv_params.groups; + + if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && + conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && + conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && + conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && + conv_params.wt_strides[1] == conv_params.wS[1] && + conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { + return depthwise_conv_2D_gpu(s, d, in, wt, out, 
conv_params); + } + + if ((C_per_group <= 4 || C_per_group % 16 == 0) && + (O_per_group <= 16 || O_per_group % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } else { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } + } + + // Direct to winograd conv + bool inp_large = + (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096; + bool channels_large = (conv_params.C + conv_params.O) >= 256; + bool out_large = + (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256; + if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one && + conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && + conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && + channels_large) { + return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + } + + // Direct to implicit gemm conv + if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && + (conv_params.O <= 16 || conv_params.O % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) { + return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); + } + + // Direct to explicit gemm conv + else { + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); + } +} + +void conv_1D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + const std::vector& in_dilation, + int groups, + bool flip, + std::vector& copies) { + bool is_idil_one = in_dilation[0] == 1; + int C = in.shape(2); + int O = wt.shape(0); + const int C_per_group = in.shape(2) / groups; + const int O_per_group = wt.shape(0) / groups; + + // Direct to implicit gemm conv + if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && + (O_per_group <= 16 || O_per_group % 16 == 0)) { + MLXConvParams<2> conv_params{ + /* const int N = */ static_cast(in.shape(0)), + /* const int C = */ C, + /* const int O = */ O, + /* const int iS[NDIM] = */ {static_cast(in.shape(1)), 1}, + /* const int wS[NDIM] = */ {static_cast(wt.shape(1)), 1}, + /* const int oS[NDIM] = */ {static_cast(out.shape(1)), 1}, + /* const int str[NDIM] = */ {wt_strides[0], 1}, + /* const int pad[NDIM] = */ {padding[0], 0}, + /* const int kdil[NDIM] = */ {wt_dilation[0], 1}, + /* const int idil[NDIM] = */ {in_dilation[0], 1}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], in.strides()[1], 0, in.strides()[2]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], 0, out.strides()[2]}, + /* const int groups = */ groups, + /* const bool flip = */ flip}; + + dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + return; + } + + // Make conv params + MLXConvParams<1> conv_params{ + /* const int N = */ static_cast(in.shape(0)), + /* const int C = */ static_cast(in.shape(2)), + /* const int O = */ static_cast(wt.shape(0)), + /* const int iS[NDIM] = */ {static_cast(in.shape(1))}, + /* const int wS[NDIM] = */ {static_cast(wt.shape(1))}, + /* const int oS[NDIM] = */ {static_cast(out.shape(1))}, + /* const int str[NDIM] = */ {wt_strides[0]}, + /* const int pad[NDIM] = */ {padding[0]}, + /* const int kdil[NDIM] = */ {wt_dilation[0]}, + /* const int idil[NDIM] = */ {in_dilation[0]}, + /* const size_t in_strides[NDIM 
+ 2] = */ + {in.strides()[0], in.strides()[1], in.strides()[2]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], out.strides()[2]}, + /* const int groups = */ groups, + /* const bool flip = */ flip}; + + // Direct to explicit gemm conv + if (groups > 1) { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } else { + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); + } +} + void conv_2D_gpu( const Stream& s, metal::Device& d, @@ -808,57 +887,7 @@ void conv_2D_gpu( /* const int groups = */ groups, /* const bool flip = */ flip, }; - - bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; - bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; - bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; - - if (is_idil_one && groups > 1) { - const int C_per_group = conv_params.C / groups; - const int O_per_group = conv_params.O / groups; - - if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && - conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && - conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && - conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && - conv_params.wt_strides[1] == conv_params.wS[1] && - conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { - return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); - } - - if ((C_per_group <= 4 || C_per_group % 16 == 0) && - (O_per_group <= 16 || O_per_group % 16 == 0)) { - return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); - } else { - return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); - } - } - - // Direct to winograd conv - bool inp_large = - (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; - bool channels_large = (conv_params.C + conv_params.O) >= 256; - if (!flip && is_stride_one && is_kdil_one && is_idil_one && - conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && - conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && - channels_large) { - return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); - } - - // Direct to implicit gemm conv - if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && - (conv_params.O <= 16 || conv_params.O % 16 == 0)) { - return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); - } - - else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) { - return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); - } - - // Direct to explicit gemm conv - else { - return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); - } + dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); } void conv_3D_gpu( @@ -952,7 +981,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -967,7 +996,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -983,12 +1012,13 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, groups_, - flip_); + flip_, + copies); } // Throw error else { diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 3399201de..915fc69fd 100644 --- a/mlx/backend/metal/copy.cpp +++ 
b/mlx/backend/metal/copy.cpp @@ -1,35 +1,15 @@ // Copyright © 2023-2024 Apple Inc. -#include - +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; -void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - bool donated = set_copy_output_data(in, out, ctype); - if (donated && in.dtype() == out.dtype()) { - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. - return; - } - if (ctype == CopyType::GeneralGeneral) { - ctype = CopyType::General; - } - copy_gpu_inplace(in, out, ctype, s); -} - -void copy_gpu(const array& in, array& out, CopyType ctype) { - copy_gpu(in, out, ctype, out.primitive().stream()); -} - void copy_gpu_inplace( const array& in, array& out, @@ -75,10 +55,10 @@ void copy_gpu_inplace( std::string kernel_name; switch (ctype) { case CopyType::Scalar: - kernel_name = (large ? "s2" : "s"); + kernel_name = large ? "s2" : "s"; break; case CopyType::Vector: - kernel_name = (large ? "v2" : "v"); + kernel_name = large ? "v2" : "v"; break; case CopyType::General: kernel_name = "g"; @@ -104,6 +84,11 @@ void copy_gpu_inplace( "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); } } + } else { + work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); + if (!large && work_per_thread > 1) { + kernel_name += "n"; + } } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -165,48 +150,33 @@ void copy_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } -void copy_gpu_inplace( - const array& in, - array& out, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); -} - -void copy_gpu_inplace( - const array& in, - array& out, - const Strides& i_strides, - int64_t i_offset, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); -} - void fill_gpu(const array& val, array& out, const Stream& s) { if (out.size() == 0) { return; } out.set_data(allocator::malloc(out.nbytes())); bool large = out.data_size() > UINT32_MAX; + int work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); auto& d = metal::device(s.device); - std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + - type_to_name(val) + type_to_name(out); + std::string kernel_name = large ? 
"s2" : (work_per_thread > 1 ? "sn" : "s"); + concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out)); auto kernel = get_copy_kernel(d, kernel_name, val, out); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -215,13 +185,19 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_output_array(out, 1); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 8a672289a..161503a0e 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -1,12 +1,326 @@ // Copyright © 2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include +#include + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" +#include "mlx/utils.h" namespace mlx::core::fast { +struct CustomKernelCache { + std::unordered_map libraries; +}; + +static CustomKernelCache& cache() { + static CustomKernelCache cache_; + return cache_; +}; + +std::string write_signature( + std::string func_name, + const std::string& header, + const std::string& source, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + const std::vector& output_dtypes, + const std::vector>& template_args, + const std::vector& attributes, + const std::vector& shape_infos, + bool atomic_outputs) { + std::string kernel_source; + kernel_source.reserve(header.size() + source.size() + 16384); + kernel_source += header; + // Auto-generate a function signature based on `template_args` + // and the dtype/shape of the arrays passed as `inputs`. + if (!template_args.empty()) { + kernel_source += "template <"; + int i = 0; + for (const auto& [name, arg] : template_args) { + std::string param_type; + if (std::holds_alternative(arg)) { + param_type = "int"; + } else if (std::holds_alternative(arg)) { + param_type = "bool"; + } else if (std::holds_alternative(arg)) { + param_type = "typename"; + } + if (i > 0) { + kernel_source += ", "; + } + kernel_source += param_type; + kernel_source += " "; + kernel_source += name; + i++; + } + kernel_source += ">\n"; + } + kernel_source += "[[kernel]] void "; + kernel_source += func_name; + kernel_source += "(\n"; + + int index = 0; + constexpr int max_constant_array_size = 8; + // Add inputs + for (int i = 0; i < inputs.size(); ++i) { + const auto& name = input_names[i]; + const auto& arr = inputs[i]; + auto dtype = get_type_string(arr.dtype()); + std::string location = + arr.size() < max_constant_array_size ? "constant" : "device"; + std::string ref = arr.ndim() == 0 ? 
"&" : "*"; + kernel_source += " const "; + kernel_source += location; + kernel_source += " "; + kernel_source += dtype; + kernel_source += ref; + kernel_source += " "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]],\n"; + index++; + // Add input shape, strides and ndim if present in the source + if (arr.ndim() > 0) { + if (shape_infos[i].shape) { + kernel_source += + (" const constant int* " + name + "_shape [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + if (shape_infos[i].strides) { + kernel_source += + (" const constant int64_t* " + name + "_strides [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + if (shape_infos[i].ndim) { + kernel_source += + (" const constant int& " + name + "_ndim [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + } + } + // Add outputs + for (int i = 0; i < output_names.size(); ++i) { + const auto& name = output_names[i]; + const auto& dtype = output_dtypes[i]; + kernel_source += " device "; + auto type_string = get_type_string(dtype); + if (atomic_outputs) { + kernel_source += "atomic<"; + } + kernel_source += type_string; + if (atomic_outputs) { + kernel_source += ">"; + } + kernel_source += "* "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]]"; + if (index < inputs.size() + output_names.size() - 1 || + attributes.size() > 0) { + kernel_source += ",\n"; + } else { + kernel_source += ") {\n"; + } + index++; + } + + index = 0; + for (const auto& attr : attributes) { + kernel_source += attr; + if (index < attributes.size() - 1) { + kernel_source += ",\n"; + } else { + kernel_source += ") {\n"; + } + index++; + } + kernel_source += source; + kernel_source += "\n}\n"; + return kernel_source; +} + +std::string write_template( + const std::vector>& template_args) { + std::ostringstream template_def; + template_def << "<"; + int i = 0; + for (const auto& [name, arg] : template_args) { + if (i > 0) { + template_def << ", "; + } + if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << get_type_string(std::get(arg)); + } + i++; + } + template_def << ">"; + return template_def.str(); +} + +MetalKernelFunction metal_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header /* = "" */, + bool ensure_row_contiguous /* = true */, + bool atomic_outputs /* = false */) { + if (output_names.empty()) { + throw std::invalid_argument( + "[metal_kernel] Must specify at least one output."); + } + std::vector shape_infos; + for (auto& n : input_names) { + CustomKernelShapeInfo shape_info; + shape_info.shape = source.find(n + "_shape") != std::string::npos; + shape_info.strides = source.find(n + "_strides") != std::string::npos; + shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + shape_infos.push_back(shape_info); + } + const std::vector> metal_attributes = { + {"dispatch_quadgroups_per_threadgroup", "uint"}, + {"dispatch_simdgroups_per_threadgroup", "uint"}, + {"dispatch_threads_per_threadgroup", "uint3"}, + {"grid_origin", "uint3"}, + {"grid_size", "uint3"}, + {"quadgroup_index_in_threadgroup", "uint"}, + {"quadgroups_per_threadgroup", "uint"}, + {"simdgroup_index_in_threadgroup", "uint"}, + 
{"simdgroups_per_threadgroup", "uint"}, + {"thread_execution_width", "uint"}, + {"thread_index_in_quadgroup", "uint"}, + {"thread_index_in_simdgroup", "uint"}, + {"thread_index_in_threadgroup", "uint"}, + {"thread_position_in_grid", "uint3"}, + {"thread_position_in_threadgroup", "uint3"}, + {"threadgroup_position_in_grid", "uint3"}, + {"threadgroups_per_grid", "uint3"}, + {"threads_per_grid", "uint3"}, + {"threads_per_simdgroup", "uint"}, + {"threads_per_threadgroup", "uint3"}, + }; + + std::vector attributes; + for (const auto& [attr, dtype] : metal_attributes) { + if (source.find(attr) != std::string::npos) { + attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); + } + } + + return [=, + shape_infos = std::move(shape_infos), + attributes = std::move(attributes)]( + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::vector>& + template_args = {}, + std::optional init_value = std::nullopt, + bool verbose = false, + StreamOrDevice s_ = {}) { + if (inputs.size() != input_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `inputs` to have size " + << input_names.size() << " but got size " << inputs.size() << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_shapes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_shapes` to have size " + << output_names.size() << " but got size " << output_shapes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_dtypes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_dtypes` to have size " + << output_names.size() << " but got size " << output_dtypes.size() + << "." 
<< std::endl; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + if (s.device != Device::gpu) { + throw std::invalid_argument("[metal_kernel] Only supports the GPU."); + } + + std::string kernel_name = "custom_kernel_" + name; + std::string template_def = ""; + if (!template_args.empty()) { + std::regex disallowed_chars("\\<|\\>|(, )"); + template_def = write_template(template_args); + auto template_hash = + std::regex_replace(template_def, disallowed_chars, "_"); + template_hash.pop_back(); + kernel_name += "_"; + kernel_name += template_hash; + } + + std::string kernel_source = write_signature( + kernel_name, + header, + source, + input_names, + inputs, + output_names, + output_dtypes, + template_args, + attributes, + shape_infos, + atomic_outputs); + + if (!template_args.empty()) { + template_def = kernel_name + template_def; + kernel_source += "\ntemplate [[host_name(\""; + kernel_source += kernel_name; + kernel_source += "\")]] [[kernel]] decltype("; + kernel_source += template_def; + kernel_source += ") "; + kernel_source += template_def; + kernel_source += ";\n"; + } + + if (verbose) { + std::cout << "Generated source code for `" << name << "`:" << std::endl + << "```" << std::endl + << kernel_source << std::endl + << "```" << std::endl; + } + + return array::make_arrays( + std::move(output_shapes), + std::move(output_dtypes), + std::make_shared( + s, + std::move(kernel_name), + std::move(kernel_source), + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value), + std::move(inputs)); + }; +} + void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -39,9 +353,23 @@ void CustomKernel::eval_gpu( } auto& d = metal::device(s.device); - const auto& lib_name = name_; - auto lib = - d.get_library(lib_name, [this] { return metal::utils() + source_; }); + + { + // Clear kernels from the device library cache if needed + auto& kernel_cache = cache(); + if (auto it = kernel_cache.libraries.find(name_); + it != kernel_cache.libraries.end()) { + if (it->second != source_) { + auto& d = metal::device(s.device); + d.clear_library(name_); + it->second = source_; + } + } else { + kernel_cache.libraries.emplace(name_, source_); + } + } + + auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -73,6 +401,16 @@ void CustomKernel::eval_gpu( } const auto [tx, ty, tz] = threadgroup_; + auto tg_size = tx * ty * tz; + auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup(); + if (tg_size > max_tg_size) { + std::ostringstream msg; + msg << "Thread group size (" << tg_size << ") is greater than " + << " the maximum allowed threads per threadgroup (" << max_tg_size + << ")."; + throw std::invalid_argument(msg.str()); + } + const auto [gx, gy, gz] = grid_; MTL::Size group_dims = MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 95aeb1cc9..e22d9da2d 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -3,15 +3,13 @@ #include #include -#include - #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION +#include "mlx/backend/common/utils.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/utils.h" #include 
"mlx/utils.h" @@ -66,8 +64,8 @@ MTL::Library* try_load_bundle( if (bundle != nullptr) { std::string resource_path = std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + - lib_name + ".metallib" auto [lib, error] = - load_library_from_path(device, resource_path.c_str()); + lib_name + ".metallib"; + auto [lib, error] = load_library_from_path(device, resource_path.c_str()); if (lib) { return lib; } @@ -79,12 +77,13 @@ MTL::Library* try_load_bundle( // Firstly, search for the metallib in the same path as this binary std::pair load_colocated_library( MTL::Device* device, - const std::string& lib_name) { - std::string lib_path = get_colocated_mtllib_path(lib_name); - if (lib_path.size() != 0) { - return load_library_from_path(device, lib_path.c_str()); + const std::string& relative_path) { + auto path = current_binary_dir() / relative_path; + if (!path.has_extension()) { + path.replace_extension(".metallib"); } - return {nullptr, nullptr}; + + return load_library_from_path(device, path.c_str()); } std::pair load_swiftpm_library( @@ -99,7 +98,7 @@ std::pair load_swiftpm_library( auto bundles = NS::Bundle::allBundles(); for (int i = 0, c = (int)bundles->count(); i < c; i++) { auto bundle = reinterpret_cast(bundles->object(i)); - library = try_load_bundle(device, bundle->resourceURL()); + library = try_load_bundle(device, bundle->resourceURL(), lib_name); if (library != nullptr) { return {library, nullptr}; } @@ -109,33 +108,34 @@ std::pair load_swiftpm_library( } MTL::Library* load_default_library(MTL::Device* device) { - NS::Error *error1, *error2, *error3; + NS::Error* error[4]; MTL::Library* lib; // First try the colocated mlx.metallib - std::tie(lib, error1) = load_colocated_library(device, "mlx"); + std::tie(lib, error[0]) = load_colocated_library(device, "mlx"); + if (lib) { + return lib; + } + + std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx"); if (lib) { return lib; } // Then try default.metallib in a SwiftPM bundle if we have one - std::tie(lib, error2) = load_swiftpm_library(device, "default"); + std::tie(lib, error[2]) = load_swiftpm_library(device, "default"); if (lib) { return lib; } // Finally try default_mtllib_path - std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path); + std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path); if (!lib) { std::ostringstream msg; msg << "Failed to load the default metallib. "; - if (error1 != nullptr) { - msg << error1->localizedDescription()->utf8String() << " "; - } - if (error2 != nullptr) { - msg << error2->localizedDescription()->utf8String() << " "; - } - if (error3 != nullptr) { - msg << error3->localizedDescription()->utf8String() << " "; + for (int i = 0; i < 4; i++) { + if (error[i] != nullptr) { + msg << error[i]->localizedDescription()->utf8String() << " "; + } } throw std::runtime_error(msg.str()); } @@ -156,6 +156,7 @@ MTL::Library* load_library( << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } + return lib; } // We have been given a path so try to load from lib_path / lib_name.metallib @@ -168,6 +169,7 @@ MTL::Library* load_library( << "> with error " << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } + return lib; } // Try to load the colocated library @@ -188,8 +190,8 @@ MTL::Library* load_library( std::ostringstream msg; msg << "Failed to load the metallib " << lib_name << ".metallib. 
" - << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name) - << ">"; + << "We attempted to load it from <" << current_binary_dir() << "/" + << lib_name << ".metallib" << ">"; #ifdef SWIFTPM_BUNDLE msg << " and from the Swift PM bundle."; #endif @@ -286,8 +288,11 @@ void CommandEncoder::barrier() { Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); - library_map_ = {{"mlx", load_default_library(device_)}}; + default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); + int ag_tens = arch_[arch_.size() - 3] - '0'; + int ag_ones = arch_[arch_.size() - 2] - '0'; + arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { case 'p': // phone @@ -317,11 +322,11 @@ Device::Device() { Device::~Device() { auto pool = new_scoped_memory_pool(); - for (auto& k : kernel_map_) { - k.second->release(); - } - for (auto& l : library_map_) { - l.second->release(); + for (auto& [l, kernel_map] : library_kernels_) { + l->release(); + for (auto& [_, k] : kernel_map) { + k->release(); + } } stream_map_.clear(); device_->release(); @@ -465,13 +470,24 @@ CommandEncoder& Device::get_command_encoder(int index) { return *stream.encoder; } -void Device::register_library( - const std::string& lib_name, - const std::string& lib_path) { - if (auto it = library_map_.find(lib_name); it == library_map_.end()) { - auto new_lib = load_library(device_, lib_name, lib_path.c_str()); - library_map_.insert({lib_name, new_lib}); +MTL::Library* Device::get_library( + const std::string& name, + const std::string& path /* = "" */) { + { + std::shared_lock rlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + return it->second; + } } + + std::unique_lock wlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + return it->second; + } + + auto new_lib = load_library(device_, name, path.c_str()); + library_map_.insert({name, new_lib}); + return new_lib; } MTL::Library* Device::build_library_(const std::string& source_string) { @@ -640,6 +656,19 @@ MTL::Library* Device::get_library( return mtl_lib; } +void Device::clear_library(const std::string& name) { + std::unique_lock wlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + auto kernel_map_it = library_kernels_.find(it->second); + for (auto& [_, kernel] : kernel_map_it->second) { + kernel->release(); + } + library_kernels_.erase(kernel_map_it); + it->second->release(); + library_map_.erase(it); + } +} + MTL::LinkedFunctions* Device::get_linked_functions_( const std::vector& funcs) { if (funcs.empty()) { @@ -670,6 +699,7 @@ MTL::ComputePipelineState* Device::get_kernel_( std::unique_lock wlock(kernel_mtx_); // Try loading again to avoid loading twice + auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) { return it->second; } @@ -704,6 +734,7 @@ MTL::ComputePipelineState* Device::get_kernel( std::shared_lock lock(kernel_mtx_); // Look for cached kernel + auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { return it->second; } @@ -713,23 +744,11 @@ MTL::ComputePipelineState* Device::get_kernel( MTL::ComputePipelineState* Device::get_kernel( const std::string& base_name, - const std::string& lib_name /* = "mlx" */, const std::string& hash_name /* = "" */, const MTLFCList& func_consts /* = {} */, const std::vector& linked_functions /* 
= {} */) { - const auto& kname = hash_name.size() == 0 ? base_name : hash_name; - { - // Multiple readers allowed - std::shared_lock lock(kernel_mtx_); - - // Look for cached kernel - if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { - return it->second; - } - } - // Search for cached metal lib - MTL::Library* mtl_lib = get_library_(lib_name); - return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions); + return get_kernel( + base_name, default_library_, hash_name, func_consts, linked_functions); } void Device::set_residency_set(const MTL::ResidencySet* residency_set) { @@ -760,42 +779,4 @@ std::unique_ptr> new_scoped_memory_pool() { NS::AutoreleasePool::alloc()->init(), dtor); } -void new_stream(Stream stream) { - if (stream.device == mlx::core::Device::gpu) { - device(stream.device).new_queue(stream.index); - } -} - -const std::unordered_map>& -device_info() { - auto init_device_info = []() - -> std::unordered_map> { - auto pool = new_scoped_memory_pool(); - auto raw_device = device(default_device()).mtl_device(); - auto name = std::string(raw_device->name()->utf8String()); - auto arch = std::string(raw_device->architecture()->name()->utf8String()); - - size_t memsize = 0; - size_t length = sizeof(memsize); - sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); - - size_t rsrc_limit = 0; - sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); - if (rsrc_limit == 0) { - rsrc_limit = 499000; - } - - return { - {"device_name", name}, - {"architecture", arch}, - {"max_buffer_length", raw_device->maxBufferLength()}, - {"max_recommended_working_set_size", - raw_device->recommendedMaxWorkingSetSize()}, - {"memory_size", memsize}, - {"resource_limit", rsrc_limit}}; - }; - static auto device_info_ = init_device_info(); - return device_info_; -} - } // namespace mlx::core::metal diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bb0e93147..52595e6e6 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -3,8 +3,6 @@ #pragma once #include -#include -#include #include #include #include @@ -15,26 +13,8 @@ #include "mlx/array.h" #include "mlx/device.h" -namespace fs = std::filesystem; - namespace mlx::core::metal { -// Note, this function must be left inline in a header so that it is not -// dynamically linked. 
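The Device::get_library and get_kernel_ changes above follow a read-mostly, double-checked pattern: look up under a shared lock for the fast path, then take the unique lock and re-check before building. A generic sketch of that pattern, assuming nothing MLX-specific beyond the standard library:

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

template <typename T>
class CachedRegistry {
 public:
  // builder is only invoked when the entry is missing.
  T get_or_create(const std::string& key, const std::function<T()>& builder) {
    {
      std::shared_lock<std::shared_mutex> rlock(mtx_);  // many concurrent readers
      if (auto it = map_.find(key); it != map_.end()) {
        return it->second;
      }
    }
    std::unique_lock<std::shared_mutex> wlock(mtx_);    // single writer
    if (auto it = map_.find(key); it != map_.end()) {
      return it->second;  // someone else built it while we waited for the lock
    }
    auto value = builder();
    map_.emplace(key, value);
    return value;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, T> map_;
};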
-inline std::string get_colocated_mtllib_path(const std::string& lib_name) { - Dl_info info; - std::string mtllib_path; - std::string lib_ext = lib_name + ".metallib"; - - int success = dladdr((void*)get_colocated_mtllib_path, &info); - if (success) { - auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; - mtllib_path = mtllib.c_str(); - } - - return mtllib_path; -} - using MTLFCList = std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>; @@ -99,6 +79,10 @@ struct CommandEncoder { return enc_->setBytes(&v, sizeof(T), idx); } + void set_threadgroup_memory_length(size_t length, int idx) { + enc_->setThreadgroupMemoryLength(length, idx); + } + ConcurrentContext start_concurrent() { return ConcurrentContext(*this); } @@ -177,6 +161,10 @@ class Device { return arch_; } + int get_architecture_gen() const { return arch_gen_; } + void new_queue(int index); MTL::CommandQueue* get_queue(Stream stream); @@ -187,14 +175,16 @@ CommandEncoder& get_command_encoder(int index); void end_encoding(int index); - void register_library( - const std::string& lib_name, - const std::string& lib_path = ""); + MTL::Library* get_library( + const std::string& name, + const std::string& path = ""); MTL::Library* get_library( const std::string& name, const std::function<std::string(void)>& builder); + void clear_library(const std::string& name); + MTL::ComputePipelineState* get_kernel( const std::string& base_name, MTL::Library* mtl_lib, @@ -204,7 +194,6 @@ MTL::ComputePipelineState* get_kernel( const std::string& base_name, - const std::string& lib_name = "mlx", const std::string& hash_name = "", const MTLFCList& func_consts = {}, const std::vector<MTL::Function*>& linked_functions = {}); @@ -258,16 +247,22 @@ std::unordered_map stream_map_; std::shared_mutex kernel_mtx_; - std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_; - std::shared_mutex library_mtx_; std::unordered_map<std::string, MTL::Library*> library_map_; + MTL::Library* default_library_; + std::unordered_map< MTL::Library*, std::unordered_map<std::string, MTL::ComputePipelineState*>> library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; + int arch_gen_; int max_ops_per_buffer_; int max_mb_per_buffer_; }; Device& device(mlx::core::Device); +std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool(); + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 82e8fff7d..a800d2e0f 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -4,7 +4,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp new file mode 100644 index 000000000..49783200a --- /dev/null +++ b/mlx/backend/metal/eval.cpp @@ -0,0 +1,102 @@ +// Copyright © 2023-2024 Apple Inc. 
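For context on the MTLFCList type restored above and the align_C flag added earlier in the conv hunk: a specialization flag travels to Metal as a function constant plus a distinct hash name, so each specialization gets its own pipeline-cache entry. The helper below is a sketch under that assumption; it mirrors the call shape d.get_kernel(base_name, hash_name, func_consts) used by the depthwise conv path, but the function itself is hypothetical and not part of this patch.

#include <string>
#include "mlx/backend/metal/device.h"

namespace mlx::core::metal {

// Hypothetical helper: fetch a kernel specialized on whether C divides bk,
// using function constant index 200 as in the conv hunk earlier in the diff.
MTL::ComputePipelineState* get_aligned_kernel(
    Device& d,
    const std::string& base_name,
    int C,
    int bk) {
  bool align_C = (C % bk) == 0;
  MTLFCList func_consts = {
      {&align_C, MTL::DataType::DataTypeBool, 200},
  };
  std::string hash_name = base_name + (align_C ? "_alC_1" : "_alC_0");
  return d.get_kernel(base_name, hash_name, func_consts);
}

} // namespace mlx::core::metal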
+#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream stream) { + if (stream.device == mlx::core::Device::gpu) { + metal::device(stream.device).new_queue(stream.index); + } +} + +inline void check_error(MTL::CommandBuffer* cbuf) { + if (cbuf->status() == MTL::CommandBufferStatusError) { + std::ostringstream msg; + msg << "[METAL] Command buffer execution failed: " + << cbuf->error()->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } +} + +void eval(array& arr) { + auto pool = metal::new_scoped_memory_pool(); + auto s = arr.primitive().stream(); + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); + + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + + debug_set_primitive_buffer_label(command_buffer, arr.primitive()); + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + + if (d.command_buffer_needs_commit(s.index)) { + d.end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + scheduler::notify_task_completion(s); + check_error(cbuf); + }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + check_error(cbuf); + }); + } +} + +void finalize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + d.end_encoding(s.index); + cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); +} + +void synchronize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + cb->retain(); + d.end_encoding(s.index); + d.commit_command_buffer(s.index); + cb->waitUntilCompleted(); + check_error(cb); + cb->release(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 246d6bcc5..eb7f1b58a 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -2,7 +2,6 @@ #include "mlx/event.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" namespace mlx::core { diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index e784d34ae..5abdf7309 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. 
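One detail of gpu::eval above worth spelling out: the completed handler captures the set of shared data pointers so the input buffers cannot be freed or reused before the GPU work that reads them has finished. Below is a Metal-free sketch of that lifetime trick, using stand-in types rather than MLX ones.

#include <functional>
#include <memory>
#include <unordered_set>
#include <vector>

// Stand-in for array::Data; only the shared ownership matters here.
using DataPtr = std::shared_ptr<int>;

std::function<void()> make_completion_handler(
    const std::vector<DataPtr>& inputs) {
  std::unordered_set<DataPtr> buffers(inputs.begin(), inputs.end());
  // The returned callable owns `buffers`; the reference counts only drop
  // once the handler has run and been destroyed, i.e. after the GPU work.
  return [buffers = std::move(buffers)]() { /* check errors, notify, ... */ };
}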
#include "mlx/fence.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" #include "mlx/utils.h" @@ -139,7 +138,7 @@ void Fence::update(Stream stream, const array& x) { compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_bytes(nthreads, 1); - compute_encoder.dispatch_threadgroups(group_dims, grid_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Barrier on previous kernels compute_encoder.barrier(); diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 153c62c02..1e23160a6 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -7,10 +7,10 @@ #include "mlx/3rdparty/pocketfft.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/binary.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" @@ -632,7 +632,7 @@ void fft_op( func_consts.push_back(make_int(&rader_m, 3)); // The overall number of FFTs we're going to compute for this input - int size = out.dtype() == float32 ? out.size() : in.size(); + size_t size = out.dtype() == float32 ? out.size() : in.size(); if (real && inverse && four_step_params.required) { size = out.size(); } @@ -659,8 +659,6 @@ void fft_op( // We can perform 2 RFFTs at once so the batch size is halved. batch_size = (batch_size + 2 - 1) / 2; } - int out_buffer_size = out.size(); - auto& compute_encoder = d.get_command_encoder(s.index); auto in_type_str = in.dtype() == float32 ? "float" : "float2"; auto out_type_str = out.dtype() == float32 ? "float" : "float2"; diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index a7dfc5f17..65a877151 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -1,11 +1,9 @@ // Copyright © 2024 Apple Inc. -#include - -#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/hadamard.h" +#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/kernels.h" @@ -15,7 +13,6 @@ namespace mlx::core { constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256; -constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB std::string gen_hadamard_codelet(int m) { // Generate a O(m^2) hadamard codelet for a given M @@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) { return source.str(); } -void Hadamard::eval_gpu(const std::vector& inputs, array& out) { - auto& s = stream(); +void hadamard_mn_contiguous( + const array& x, + array& y, + int m, + int n1, + int n2, + float scale, + metal::Device& d, + const Stream& s) { + int n = n1 * n2; + int read_width_n1 = n1 == 2 ? 2 : 4; + int read_width_n2 = n2 == 2 ? 2 : 4; + int read_width_m = (n == 2 || m == 28) ? 2 : 4; + int max_radix_1 = std::min(n1, 16); + int max_radix_2 = std::min(n2, 16); + float scale_n1 = 1.0; + float scale_n2 = (m == 1) ? 
scale : 1.0; + float scale_m = scale; - auto& in = inputs[0]; + // n2 is a row contiguous power of 2 hadamard transform + MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1); + MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1); - std::vector copies; - // Only support the last axis for now - int axis = in.ndim() - 1; - auto check_input = [&copies, &s](const array& x) { - // TODO(alexbarron) pass strides to kernel to relax this constraint - bool no_copy = x.flags().row_contiguous; - if (no_copy) { - return x; - } else { - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + // n1 is a strided power of 2 hadamard transform with stride n2 + MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1); + MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2); + + // m is a strided hadamard transform with stride n = n1 * n2 + MTL::Size group_dims_m( + std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1); + MTL::Size grid_dims_m( + group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1); + + // Make the kernel + std::string kname; + kname.reserve(32); + concatenate(kname, "hadamard_", n * m, "_", type_to_name(x)); + auto lib = d.get_library(kname, [&]() { + std::string kernel; + concatenate( + kernel, + metal::utils(), + gen_hadamard_codelet(m), + metal::hadamard(), + get_template_definition( + "n2" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n2, + max_radix_2, + read_width_n2)); + if (n1 > 1) { + kernel += get_template_definition( + "n1" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n1, + max_radix_1, + read_width_n1, + n2); } - }; - const array& in_contiguous = check_input(in); - - if (in_contiguous.is_donatable()) { - out.copy_shared_buffer(in_contiguous); - } else { - out.set_data(allocator::malloc(out.nbytes())); - } - - int n, m; - std::tie(n, m) = decompose_hadamard(in.shape(axis)); - - if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) { - throw std::invalid_argument( - "[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI"); - } - - int max_radix = std::min(n, 16); - // Use read_width 2 for m = 28 to avoid register spilling - int read_width = (n == 2 || m == 28) ? 
2 : 4; - - std::ostringstream kname; - kname << "hadamard_" << n * m << "_" << type_to_name(out); - auto kernel_name = kname.str(); - auto& d = metal::device(s.device); - const auto& lib_name = kernel_name; - auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - auto codelet = gen_hadamard_codelet(m); - kernel_source << metal::utils() << codelet << metal::hadamard(); - kernel_source << get_template_definition( - "n" + kernel_name, - "hadamard_n", - get_type_string(in.dtype()), - n, - max_radix, - read_width); - kernel_source << get_template_definition( - "m" + kernel_name, - "hadamard_m", - get_type_string(in.dtype()), - n, - m, - read_width); - return kernel_source.str(); + if (m > 1) { + kernel += get_template_definition( + "m" + kname, + "hadamard_m", + get_type_string(x.dtype()), + n, + m, + read_width_m); + } + return kernel; }); - int batch_size = in.size() / n; - int threads_per = n / max_radix; - - auto& compute_encoder = d.get_command_encoder(s.index); - - auto launch_hadamard = [&](const array& in, - array& out, - const std::string& kernel_name, - float scale) { - auto kernel = d.get_kernel(kernel_name, lib); - assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup()); - + // Launch the strided transform for n1 + if (n1 > 1) { + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n1" + kname, lib); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(scale, 2); - - MTL::Size group_dims = MTL::Size(1, threads_per, 1); - MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - }; - - if (m > 1) { - // When m is greater than 1, we decompose the - // computation into two uploads to the GPU: - // - // e.g. len(x) = 12*4 = 48, m = 12, n = 4 - // - // y = h48 @ x - // - // Upload 1: - // tmp = a.reshape(12, 4) @ h4 - // - // Upload 2: - // y = h12 @ tmp - array temp(in.shape(), in.dtype(), nullptr, {}); - temp.set_data(allocator::malloc(temp.nbytes())); - copies.push_back(temp); - - launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0); - - // Metal sometimes reports 256 max threads per group for hadamard_m kernel - threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP); - batch_size = in.size() / m / read_width / threads_per; - launch_hadamard(temp, out, "m" + kernel_name, scale_); - } else { - launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n1, 2); + compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1); } - d.add_temporaries(std::move(copies), s.index); + // Launch the transform for n2 + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n2" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(n1 > 1 ? 
y : x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n2, 2); + compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2); + + // Launch the strided transform for m + if (m > 1) { + auto kernel = d.get_kernel("m" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(y, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_m, 2); + compute_encoder.dispatch_threads(grid_dims_m, group_dims_m); + } +} + +void Hadamard::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + // Split the hadamard transform so that all of them work on vectors smaller + // than 8192 elements. + // + // We decompose it in the following way: + // + // n = m * n1 * n2 = m * 2^k1 * 2^k2 + // + // where m is in (1, 12, 20, 28) and n1 and n2 <= 8192 + auto [n, m] = decompose_hadamard(in.shape().back()); + int n1 = 1, n2 = n; + if (n > 8192) { + for (n2 = 2; n2 * n2 < n; n2 *= 2) { + } + n1 = n / n2; + } + + if (in.flags().row_contiguous) { + if (in.is_donatable()) { + out.copy_shared_buffer(in); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s); + } else { + copy_gpu(in, out, CopyType::General, s); + hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a263051..13ce88a62 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,7 +2,8 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" @@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = src.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(src.shape(axis_), 8); @@ -582,17 +575,17 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = upd.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + if (ndim > 1) { + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + } else { + // The 
following will be ignored in the kernel but we still have to set + // some value so that metal validation passes. + compute_encoder.set_vector_bytes(idx.shape(), 3); + compute_encoder.set_vector_bytes(upd.strides(), 4); + compute_encoder.set_vector_bytes(idx.strides(), 5); + } compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8); diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 27ae22d05..b380a8374 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -34,6 +34,7 @@ const char* steel_gemm_fused(); const char* steel_gemm_masked(); const char* steel_gemm_splitk(); const char* steel_gemm_gather(); +const char* steel_gemm_segmented(); const char* conv(); const char* steel_conv(); const char* steel_conv_general(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 5206c9b54..6ae72e0aa 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -8,12 +8,6 @@ using namespace fmt::literals; namespace mlx::core { -std::string op_name(const array& arr) { - std::ostringstream op_t; - arr.primitive().print(op_t); - return op_t.str(); -} - MTL::ComputePipelineState* get_arange_kernel( metal::Device& d, const std::string& kernel_name, @@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto in_t = get_type_string(in_type); @@ -41,7 +35,11 @@ MTL::ComputePipelineState* get_unary_kernel( std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source += - get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); + get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1); + if (get_work_per_thread(in_type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); + } kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( @@ -54,16 +52,13 @@ MTL::ComputePipelineState* get_unary_kernel( } void append_binary_kernels( - const std::string lib_name, + const std::string& lib_name, Dtype in_type, Dtype out_type, - const std::string op, + const char* op, std::string& kernel_source) { - const std::array, 10> kernel_types = {{ + const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, - {"vs", "binary_vs"}, - {"sv", "binary_sv"}, - {"vv", "binary_vv"}, {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, @@ -78,6 +73,22 @@ void append_binary_kernels( kernel_source += get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } + kernel_source += get_template_definition( + "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1); + + if (get_work_per_thread(in_type) > 1) { + kernel_source += get_template_definition( + "vsn_" + lib_name, "binary_vs", in_t, out_t, op); + kernel_source += get_template_definition( + "svn_" + lib_name, "binary_sv", in_t, out_t, op); + kernel_source += get_template_definition( + "vvn_" + lib_name, "binary_vv", in_t, out_t, op); + } + 
kernel_source += get_template_definition( "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( @@ -95,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; @@ -112,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); @@ -127,14 +138,13 @@ MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto t_str = get_type_string(type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); - const std::array, 5> kernel_types = {{ - {"v", "ternary_v"}, + const std::array, 4> kernel_types = {{ {"v2", "ternary_v2"}, {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, @@ -144,6 +154,13 @@ MTL::ComputePipelineState* get_ternary_kernel( kernel_source += get_template_definition(name + "_" + lib_name, func, t_str, op); } + if (get_work_per_thread(type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "ternary_v", t_str, op); + } + + kernel_source += + get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1); kernel_source += get_template_definition( "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); kernel_source += get_template_definition( @@ -170,15 +187,22 @@ MTL::ComputePipelineState* get_copy_kernel( kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); - kernel_source += - get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); + kernel_source += get_template_definition( + "s_" + lib_name, "copy_s", in_type, out_type, 1); kernel_source += get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); - kernel_source += - get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); + kernel_source += get_template_definition( + "v_" + lib_name, "copy_v", in_type, out_type, 1); kernel_source += get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); + if (get_work_per_thread(out.dtype()) > 1) { + kernel_source += get_template_definition( + "sn_" + lib_name, "copy_s", in_type, out_type); + kernel_source += get_template_definition( + "vn_" + lib_name, "copy_v", in_type, out_type); + } + kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( @@ -622,6 +646,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( return d.get_kernel(kernel_name, lib, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn) { + const auto& lib_name = kernel_name; + auto lib = 
d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::steel_gemm_segmented(), + get_template_definition( + lib_name, + "segmented_mm", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, @@ -697,6 +758,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel( MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, const array& out, int bm, int bn, @@ -719,7 +782,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel( wn); return kernel_source.str(); }); - return d.get_kernel(kernel_name, lib); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_fft_kernel( diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 6d8864385..ca29ca52e 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_binary_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_binary_two_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_copy_kernel( metal::Device& d, @@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( int wn, bool rhs); +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, @@ -205,6 +219,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel( MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, const array& out, int bm, int bn, @@ -241,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( // Create a GPU kernel template definition for JIT compilation template -std::string -get_template_definition(std::string name, std::string func, Args... args) { +std::string get_template_definition( + std::string_view name, + std::string_view func, + Args... 
args) { std::ostringstream s; s << func << "<"; bool first = true; diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 3ee88ca46..4069d8c21 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -71,6 +71,7 @@ set(STEEL_HEADERS steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_masked.h + steel/gemm/kernels/steel_gemm_segmented.h steel/gemm/kernels/steel_gemm_splitk.h steel/utils/type_traits.h steel/utils/integral_constant.h) @@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) endif() diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index 7f1075ad9..4a83d8e57 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -80,9 +80,10 @@ template const constant size_t& ndim [[buffer(5)]], const constant int64_t& axis_stride [[buffer(6)]], const constant size_t& axis_size [[buffer(7)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 gsize [[threads_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_size [[threads_per_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { @@ -104,17 +105,18 @@ template // Compute the input/output index. There is one beginning and one output for // the whole threadgroup. 
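The arg_reduce kernel above now takes uint3 thread coordinates; the row being reduced is recovered just below as row_idx = gid.y + gsize.y * gid.z, so the number of rows is no longer limited to a single grid dimension. A hypothetical sketch of how a launcher could split the rows over y and z; the helper name and the per-dimension limit are assumptions for illustration, not code from this patch:

    // Hypothetical grid-sizing helper: tile n_rows over (y, z) so neither
    // dimension exceeds max_per_dim; the kernel then computes
    // row_idx = gid.y + gsize.y * gid.z.
    #include <cstddef>
    #include <utility>

    inline std::pair<std::size_t, std::size_t> split_rows(
        std::size_t n_rows, std::size_t max_per_dim) {
      std::size_t gy = n_rows < max_per_dim ? n_rows : max_per_dim;
      std::size_t gz = (n_rows + gy - 1) / gy;  // ceil division
      return {gy, gz};
    }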
- auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); + int64_t row_idx = gid.y + static_cast(gsize.y) * gid.z; + auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim); + auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim); IndexValPair best{0, Op::init}; threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { // Read the current value - uint32_t current_index = r * lsize * N_READS + lid * N_READS; + uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS; uint32_t offset = current_index; const device T* current_in = in + in_idx + current_index * axis_stride; T vals[N_READS]; @@ -144,7 +146,7 @@ template } // Read the appropriate value from local data and perform one simd reduction - uint simd_groups = ceildiv(lsize, simd_size); + uint simd_groups = ceildiv(lsize.x, simd_size); if (simd_lane_id < simd_groups) { best = local_data[simd_lane_id]; } @@ -154,7 +156,7 @@ template } // Finally write the output - if (lid == 0) { + if (lid.x == 0) { out[out_idx] = best.index; } } diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 91a02c818..f1df88535 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -9,64 +9,121 @@ template c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* 
b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } } template diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 3ef8e6269..17ed13c57 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -9,11 +9,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \ + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -26,15 +31,19 @@ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) -#define instantiate_binary_integer(op) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + +#define instantiate_binary_integer(op) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, 
uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ @@ -44,7 +53,7 @@ #define instantiate_binary_types(op) \ instantiate_binary_all(op, bool_, bool, bool) \ instantiate_binary_integer(op) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t)\ instantiate_binary_float(op) #define instantiate_binary_types_bool(op) \ @@ -52,15 +61,15 @@ instantiate_binary_all(op, uint8, uint8_t, bool) \ instantiate_binary_all(op, uint16, uint16_t, bool) \ instantiate_binary_all(op, uint32, uint32_t, bool) \ - instantiate_binary_all(op, uint64, uint64_t, bool) \ + instantiate_binary_base(op, uint64, uint64_t, bool) \ instantiate_binary_all(op, int8, int8_t, bool) \ instantiate_binary_all(op, int16, int16_t, bool) \ instantiate_binary_all(op, int32, int32_t, bool) \ - instantiate_binary_all(op, int64, int64_t, bool) \ + instantiate_binary_base(op, int64, int64_t, bool) \ instantiate_binary_all(op, float16, half, bool) \ instantiate_binary_all(op, float32, float, bool) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \ - instantiate_binary_all(op, complex64, complex64_t, bool) + instantiate_binary_base(op, complex64, complex64_t, bool) instantiate_binary_types(Add) instantiate_binary_types(Divide) @@ -71,6 +80,7 @@ instantiate_binary_types_bool(Less) instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(NotEqual) instantiate_binary_float(LogAddExp) +instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_types(Maximum) instantiate_binary_types(Minimum) instantiate_binary_types(Multiply) @@ -83,7 +93,7 @@ instantiate_binary_float(ArcTan2) instantiate_binary_all(NaNEqual, float16, half, bool) instantiate_binary_all(NaNEqual, float32, float, bool) instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool) -instantiate_binary_all(NaNEqual, complex64, complex64_t, bool) +instantiate_binary_base(NaNEqual, complex64, complex64_t, bool) instantiate_binary_all(LogicalOr, bool_, bool, bool) instantiate_binary_all(LogicalAnd, bool_, bool, bool) diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index 8f961c2cf..f4deb860e 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -130,6 +130,24 @@ struct LogAddExp { ? maxval : (maxval + log1p(metal::exp(minval - maxval))); }; + + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) || + metal::isnan(y.imag)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr float inf = metal::numeric_limits::infinity(); + complex64_t maxval = x > y ? x : y; + complex64_t minval = x < y ? 
x : y; + if (minval.real == -inf || maxval.real == inf) + return maxval; + float m = metal::exp(minval.real - maxval.real); + complex64_t dexp{ + m * metal::cos(minval.imag - maxval.imag), + m * metal::sin(minval.imag - maxval.imag), + }; + return maxval + log1p(dexp); + } }; struct Maximum { @@ -217,6 +235,13 @@ struct Power { template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index 8f6b3392d..4455e4ca9 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -12,82 +12,151 @@ template d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } 
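The binary kernels above now take a work-per-thread count N (defaulting to WorkPerThread<U>::n) and handle the ragged tail explicitly. A plain C++ analogue of the access pattern, for reference only; it is a sketch of the indexing, not the Metal code:

    // Logical thread `index` processes N contiguous elements; only the final
    // thread takes the bounded tail loop when size is not a multiple of N.
    #include <cstdint>

    template <typename T, typename U, typename Op, int N>
    void binary_vv_reference(
        const T* a, const T* b, U* c, uint32_t size, uint32_t index) {
      index *= N;
      if (N > 1 && index + N > size) {
        for (uint32_t i = 0; index + i < size; ++i) {
          c[index + i] = Op()(a[index + i], b[index + i]);
        }
      } else {
        for (int i = 0; i < N; ++i) {
          c[index + i] = Op()(a[index + i], b[index + i]);
        }
      }
    }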
-template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } template diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 984a28320..c7d3ecdf0 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,11 +7,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -24,22 +29,26 @@ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float32, float, float) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) -#define 
instantiate_binary_types(op) \ - instantiate_binary_all(op, bool_, bool, bool) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ +#define instantiate_binary_types(op) \ + instantiate_binary_all(op, bool_, bool, bool) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t) \ instantiate_binary_float(op) instantiate_binary_types(DivMod) // clang-format on diff --git a/mlx/backend/metal/kernels/cexpf.h b/mlx/backend/metal/kernels/cexpf.h new file mode 100644 index 000000000..b45fe6a2f --- /dev/null +++ b/mlx/backend/metal/kernels/cexpf.h @@ -0,0 +1,134 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2008-2013 NVIDIA Corporation +// Copyright © 2013 Filipe RNC Maia +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from +// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h + +// TODO: We should use thrust::exp but the thrust header in old CUDA versions +// can not be used in JIT. 
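cexpf.h above ports CCCL's scaled complex exponential so that exponentials with large real parts do not overflow in float. For orientation, the unscaled identity it computes, written against the C++ standard library (a reference sketch; the Metal code that follows adds the frexp-style rescaling):

    // exp(x + i*y) = e^x * (cos y + i*sin y); this naive version overflows
    // once e^x exceeds FLT_MAX, which is exactly the case cexpf rescales.
    #include <cmath>
    #include <complex>

    inline std::complex<float> cexpf_reference(std::complex<float> z) {
      float ex = std::exp(z.real());
      return {ex * std::cos(z.imag()), ex * std::sin(z.imag())};
    }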
+ +#pragma once + +#include + +using ieee_float_shape_type = union { + float value; + uint32_t word; +}; + +inline void get_float_word(thread uint32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void get_float_word(thread int32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void set_float_word(thread float& d, uint32_t i) { + ieee_float_shape_type sf_u; + sf_u.word = (i); + (d) = sf_u.value; +} + +inline float frexp_expf(float x, thread int* expt) { + const uint32_t k = 235; + const float kln2 = 162.88958740F; + + float exp_x; + uint32_t hx; + + exp_x = metal::exp(x - kln2); + get_float_word(hx, exp_x); + *expt = (hx >> 23) - (0x7f + 127) + k; + set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); + return exp_x; +} + +inline complex64_t ldexp_cexpf(complex64_t z, int expt) { + float x, y, exp_x, scale1, scale2; + int ex_expt, half_expt; + + x = z.real; + y = z.imag; + exp_x = frexp_expf(x, &ex_expt); + expt += ex_expt; + + half_expt = expt / 2; + set_float_word(scale1, (0x7f + half_expt) << 23); + half_expt = expt - half_expt; + set_float_word(scale2, (0x7f + half_expt) << 23); + + return complex64_t{ + metal::cos(y) * exp_x * scale1 * scale2, + metal::sin(y) * exp_x * scale1 * scale2}; +} + +inline complex64_t cexpf(const thread complex64_t& z) { + float x, y, exp_x; + uint32_t hx, hy; + + const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; + + x = z.real; + y = z.imag; + + get_float_word(hy, y); + hy &= 0x7fffffff; + + /* cexp(x + I 0) = exp(x) + I 0 */ + if (hy == 0) { + return complex64_t{metal::exp(x), y}; + } + get_float_word(hx, x); + /* cexp(0 + I y) = cos(y) + I sin(y) */ + if ((hx & 0x7fffffff) == 0) { + return complex64_t{metal::cos(y), metal::sin(y)}; + } + if (hy >= 0x7f800000) { + if ((hx & 0x7fffffff) != 0x7f800000) { + /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ + return complex64_t{y - y, y - y}; + } else if (hx & 0x80000000) { + /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ + return complex64_t{0.0, 0.0}; + } else { + /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ + return complex64_t{x, y - y}; + } + } + + if (hx >= exp_ovfl && hx <= cexp_ovfl) { + /* + * x is between 88.7 and 192, so we must scale to avoid + * overflow in expf(x). 
+ */ + return ldexp_cexpf(z, 0); + } else { + /* + * Cases covered here: + * - x < exp_ovfl and exp(x) won't overflow (common case) + * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 + * - x = +-Inf (generated by exp()) + * - x = NaN (spurious inexact exception from y) + */ + exp_x = metal::exp(x); + return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)}; + } +} diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index fe8ec5c0f..c88002cb3 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) { constexpr complex64_t operator+(complex64_t a, complex64_t b) { return {a.real + b.real, a.imag + b.imag}; } +constexpr complex64_t operator+(float a, complex64_t b) { + return {a + b.real, b.imag}; +} +constexpr complex64_t operator+(complex64_t a, float b) { + return {a.real + b, a.imag}; +} constexpr complex64_t operator-(complex64_t a, complex64_t b) { return {a.real - b.real, a.imag - b.imag}; } +constexpr complex64_t operator-(float a, complex64_t b) { + return {a - b.real, -b.imag}; +} +constexpr complex64_t operator-(complex64_t a, float b) { + return {a.real - b, a.imag}; +} constexpr complex64_t operator*(complex64_t a, complex64_t b) { return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; @@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) { return {x / denom, y / denom}; } +constexpr complex64_t operator/(float a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a * b.real; + auto y = -a * b.imag; + return {x / denom, y / denom}; +} + constexpr complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real - (b.real * static_cast(a.real / b.real)); auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index b1367cf4f..cf22347ee 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,39 +1,77 @@ // Copyright © 2024 Apple Inc. 
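complex.h above gains mixed float/complex64_t operators, including division of a real scalar by a complex value. A quick host-side sanity check of that rule against std::complex (illustrative only; not part of the patch):

    // 1/(a+bi) = (a - bi) / (a^2 + b^2); for b = 3 - 4i this gives 0.12 + 0.16i.
    #include <cassert>
    #include <cmath>
    #include <complex>

    int main() {
      std::complex<float> b{3.0f, -4.0f};
      auto r = 1.0f / b;  // same formula as operator/(float, complex64_t)
      assert(std::abs(r.real() - 0.12f) < 1e-6f);
      assert(std::abs(r.imag() - 0.16f) < 1e-6f);
      return 0;
    }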
-template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } } template diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index bbf268158..fcf8884f8 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -4,9 +4,13 @@ #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/copy.h" -#define instantiate_copy_all(tname, itype, otype) \ - instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ - instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ +#define instantiate_copy_work_per_thread(tname, itype, otype) \ + instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \ + instantiate_kernel("vn_copy" #tname, copy_v, itype, otype) + +#define instantiate_copy_base(tname, itype, otype) \ + instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \ + instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ @@ -18,6 +22,10 @@ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_copy_base(tname, itype, otype) \ + instantiate_copy_work_per_thread(tname, itype, otype) + #define instantiate_copy_same(tname, type) 
\ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ @@ -42,15 +50,15 @@ instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint16, itype, uint16_t) \ instantiate_copy_all(itname ##uint32, itype, uint32_t) \ - instantiate_copy_all(itname ##uint64, itype, uint64_t) \ + instantiate_copy_base(itname ##uint64, itype, uint64_t) \ instantiate_copy_all(itname ##int8, itype, int8_t) \ instantiate_copy_all(itname ##int16, itype, int16_t) \ instantiate_copy_all(itname ##int32, itype, int32_t) \ - instantiate_copy_all(itname ##int64, itype, int64_t) \ + instantiate_copy_base(itname ##int64, itype, int64_t) \ instantiate_copy_all(itname ##float16, itype, half) \ instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ - instantiate_copy_all(itname ##complex64, itype, complex64_t) + instantiate_copy_base(itname ##complex64, itype, complex64_t) instantiate_copy_itype(bool_, bool) instantiate_copy_itype(uint8, uint8_t) diff --git a/mlx/backend/metal/kernels/fft/readwrite.h b/mlx/backend/metal/kernels/fft/readwrite.h index ab699e136..0dc62992e 100644 --- a/mlx/backend/metal/kernels/fft/readwrite.h +++ b/mlx/backend/metal/kernels/fft/readwrite.h @@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so read/write performance is important. Where possible, we read 128 bits sequentially in each thread, -coalesced with accesses from adajcent threads for optimal performance. +coalesced with accesses from adjacent threads for optimal performance. We implement specialized reading/writing for: - FFT @@ -98,7 +98,7 @@ struct ReadWriter { } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -121,7 +121,7 @@ struct ReadWriter { } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -144,7 +144,7 @@ struct ReadWriter { // Padded IO for Bluestein's algorithm METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; @@ -161,7 +161,7 @@ struct ReadWriter { } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -283,7 +283,8 @@ template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; @@ -317,7 
+318,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; @@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter::load_padded( int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -503,7 +505,7 @@ template <> METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; diff --git a/mlx/backend/metal/kernels/hadamard.h b/mlx/backend/metal/kernels/hadamard.h index 93e2fb8a8..9f2311c10 100644 --- a/mlx/backend/metal/kernels/hadamard.h +++ b/mlx/backend/metal/kernels/hadamard.h @@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) { } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -46,18 +46,25 @@ template constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; // Read values from device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } @@ -113,12 +120,20 @@ template } // Write values to device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { 
STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 51570e48d..ea77b53dc 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -9,7 +9,42 @@ using namespace metal; constant bool has_w [[function_constant(20)]]; -template +template +inline void initialize_buffer( + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + if (simd_group_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_lane_id + i] = 0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +template +inline void threadgroup_sum( + thread float* x, + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + for (int i = 0; i < N; i++) { + x[i] = simd_sum(x[i]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_group_id + i] = x[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N; i++) { + x[i] = xs[N * simd_lane_id + i]; + x[i] = simd_sum(x[i]); + } +} + +template [[kernel]] void layer_norm_single_row( const device T* x, const device T* w, @@ -23,90 +58,71 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - float thread_x[N_READS]; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + threadgroup float local_buffer[SIMD_SIZE] = {0}; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); + // Advance the pointers x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; + + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - 
threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + normalizer += thread_x[i] * thread_x[i]; + } + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - out[i] = - w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } } -template +template [[kernel]] void layer_norm_looped( const device T* x, const device T* w, @@ -121,71 +137,52 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + threadgroup float local_buffer[SIMD_SIZE]; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - 
threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs out += gid * size_t(axis_size) + lid * N_READS; @@ -208,7 +205,7 @@ template } } -template +template [[kernel]] void vjp_layer_norm_single_row( const device T* x, const device T* w, @@ -222,133 +219,96 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the computation and accumulators - float thread_x[N_READS]; - float thread_w[N_READS]; - float thread_g[N_READS]; - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + float thread_w[N_READS] = {0}; + float thread_g[N_READS] = {0}; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); - constexpr int SIMD_SIZE = 32; + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; - - if (lid * N_READS + N_READS <= axis_size) { + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; + thread_w[i] = w[i * w_stride]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; - thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; + thread_g[i] = g[i]; + thread_w[i] = w[i * w_stride]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = 
simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + factors[meanwg] += thread_w[i] * thread_g[i]; + factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i]; + factors[normalizer2] += thread_x[i] * thread_x[i]; + } + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; gw += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i]); - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } } -template +template [[kernel]] void vjp_layer_norm_looped( const 
device T* x, const device T* w, @@ -363,102 +323,69 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the accumulators - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; - - constexpr int SIMD_SIZE = 32; - - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = 
x[i + r] - mean; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; @@ -470,7 +397,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } @@ -482,7 +410,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index b6898e31e..c746050b3 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -103,8 +103,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; @@ -134,10 +134,7 @@ template threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } + if (lid == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index b2b0d8d8f..0a40cec00 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -14,11 +14,23 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 
5 : 3); +} + template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { } } + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -153,8 +196,9 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -199,6 +243,26 @@ inline U qdot( } } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; @@ -234,8 +298,9 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined 
for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -280,6 +345,26 @@ inline U qdot_safe( } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { x_thread += 4 * i; @@ -310,8 +395,9 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; @@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } + } - } else if (bits == 6) { + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; @@ -375,8 +484,9 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = { @@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] 
= ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { w_local += 4 * i; w += 3 * i; - w_local[0] = (w[0] & 0x3f) * scale + bias; w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; @@ -452,11 +577,12 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -517,14 +643,14 @@ struct QuantizedBlockLoader { return; } - if (reduction_dim == 1 && bi >= src_tile_dim.y) { + if (reduction_dim == 1 && bi >= src_tile_dim.x) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } - if (reduction_dim == 0 && bi >= src_tile_dim.x) { + if (reduction_dim == 0 && bi >= src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } @@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl( uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; @@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -1008,11 +1135,11 @@ METAL_FUNC void qmm_t_impl( auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -1132,11 +1259,11 @@ METAL_FUNC void qmm_n_impl( // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -2120,11 +2247,10 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; using mma_t = mlx::steel::BlockMMA< T, @@ -2305,13 +2431,13 @@ template constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; static_assert( group_size % simd_size == 0, @@ -2354,8 +2480,8 @@ template biases[gindex] = static_cast(bias); } - // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t - uint32_t output = 0; + using OutType = metal::conditional_t; + OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { @@ -2363,27 +2489,35 @@ template if (bits == 8) { output = val; } else { - output += val << (bits * (i % packs_per_int)); + output |= val << (bits * (i % pack_factor)); } - if (packs_per_int < values_per_reduce && - i % packs_per_int == packs_per_int - 1) { - out[out_index + i / packs_per_int] = output; + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); - output += sval << (bits * (j * values_per_reduce + i)); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { - if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { out[out_index / writes_per_reduce] = output; @@ -2399,12 +2533,11 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * packs_per_int; + size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; @@ -2421,7 +2554,16 @@ template out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; @@ -2431,7 +2573,7 @@ template } else { uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { + for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 11cd8421b..de83cb657 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -136,6 +136,7 @@ instantiate_quantized_groups(2) \ instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ + instantiate_quantized_groups(5) \ instantiate_quantized_groups(6) \ instantiate_quantized_groups(8) diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h index 68ed11986..11d8e83ac 100644 --- a/mlx/backend/metal/kernels/reduction/ops.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -164,7 +164,15 @@ struct Min { DEFINE_SIMD_REDUCE() template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_min(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_min(val); } @@ -176,17 +184,52 @@ struct Min { } // Operator - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a < b ? a : b; } -}; + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a < b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a < b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real < b.real ? 
a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + }; +}; template struct Max { DEFINE_SIMD_REDUCE() template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_max(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_max(val); } @@ -198,7 +241,35 @@ struct Max { } // Operator - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a > b ? a : b; } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a > b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a > b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real > b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + } }; diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h index c8973429f..936d75bb5 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_row.h +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -224,7 +224,7 @@ template < if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. - IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 8fcd7f61b..f38f8757e 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4) instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4) -instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on +instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index c4c0f6456..8258e9c14 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -56,9 +56,9 @@ template const int head_idx = tid.x; const int q_seq_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? 
tpg.x * q_seq_idx + head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; @@ -213,9 +213,9 @@ template const int block_idx = tid.z; const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; const int kv_head_idx = head_idx / gqa_factor; queries += q_offset * D + simd_lid * qk_per_thread; @@ -358,8 +358,8 @@ template // Adjust positions const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int n_heads = tpg.x; - const int q_offset = n_heads * q_seq_idx + head_idx; + const int q_offset = head_idx * tpg.y + q_seq_idx; + ; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; sums += q_offset * blocks; maxs += q_offset * blocks; diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index b36b73bd8..6ea4ac732 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -128,8 +128,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 2e27ea06f..34d5bf58a 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -95,7 +95,7 @@ template < Q += tidl.z * params->Q_strides[0] + // Batch tidl.y * params->Q_strides[1] + // Head - tidl.x * BQ * params->Q_strides[2]; // Seqeunce + tidl.x * BQ * params->Q_strides[2]; // Sequence ulong kv_head_idx = int(tid.y) / params->gqa_factor; K += tidl.z * params->K_strides[0] + // Batch @@ -106,7 +106,7 @@ template < O += tidl.z * params->O_strides[0] + // Batch tidl.y * params->O_strides[1] + // Head - tidl.x * BQ * params->O_strides[2]; // Seqeunce + tidl.x * BQ * params->O_strides[2]; // Sequence if (has_mask) { mask += tidl.z * mask_params->M_strides[0] + // Batch diff --git a/mlx/backend/metal/kernels/steel/attn/loader.h b/mlx/backend/metal/kernels/steel/attn/loader.h index 2849c00f1..7ec798146 100644 --- a/mlx/backend/metal/kernels/steel/attn/loader.h +++ b/mlx/backend/metal/kernels/steel/attn/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); @@ -240,7 +240,7 @@ struct BlockLoaderT { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h index e4b662cd3..9afebd307 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -2,6 +2,8 @@ #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" +constant bool align_C [[function_constant(200)]]; + template < typename T, int BM, @@ -118,30 +120,65 @@ implicit_gemm_conv_2d_general( // Prepare threadgroup mma operation mma_t mma_op(simd_gid, simd_lid); - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); - // Prepare for next iteration - loader_a.next(); - loader_b.next(); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + // Load elements into threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } } threadgroup_barrier(mem_flags::mem_none); // Store results to device memory { - // Adjust for simdgroup and thread locatio + // Adjust for simdgroup and thread location int offset_m = c_row + mma_op.sm; int offset_n = c_col + mma_op.sn; C += offset_n; diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h index dad496e81..d52642b73 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h @@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader { const constant MLXConvParams<2>* params; int weight_hw; + int weight_step; const int read_n; const bool do_read; @@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader { src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), + weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} @@ -435,15 +437,15 @@ struct 
Conv2DWeightBlockLoader { /* Iteration helper */ METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += params->wt_strides[2]; + src += weight_step; return; } weight_hw = 0; - src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h index 56027916e..b0b98d21a 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { return; } - const device T* curr_src = src + weight_hw * params->wt_strides[2]; + const device T* curr_src = src + weight_hw * (params->C / params->groups); if (BN != 8 || do_read) { STEEL_PRAGMA_UNROLL @@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h index 72335e698..9b7ddc2ee 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h @@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; @@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index add495d93..85830872d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h new file mode 100644 index 000000000..b915eb343 --- /dev/null +++ 
b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h @@ -0,0 +1,266 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Move the pointers to the output tile + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Move the pointers to the start of the segment + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid.z]; + k_end = segments[2 * tid.z + 1]; + } else { + // We accept either contiguous (above) or weird strides where the beginning + // of the next one is the previous one. Basically the last two strides are + // both 1! + k_start = segments[tid.z]; + k_end = segments[tid.z + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? k_start : k_start * params->ldb; + C += tid.z * params->batch_stride_d; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Matrix level alignment so only check K + if (align_M && align_N) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } else { + // Tile aligned do the same as above + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal new file mode 100644 index 000000000..a7515c359 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal @@ -0,0 +1,43 @@ +// Copyright © 2024 Apple Inc. 
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h" + +#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + segmented_mm, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) +// clang-format on + +instantiate_segmented_mm_shapes_helper(float16, half, float16, half); +instantiate_segmented_mm_shapes_helper( + bfloat16, + bfloat16_t, + bfloat16, + bfloat16_t); +instantiate_segmented_mm_shapes_helper(float32, float, float32, float); diff --git a/mlx/backend/metal/kernels/steel/gemm/loader.h b/mlx/backend/metal/kernels/steel/gemm/loader.h index 3f084d8ec..d421b2d1f 100644 --- a/mlx/backend/metal/kernels/steel/gemm/loader.h +++ b/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 4b3adcc80..570f5e4d6 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -1,25 +1,44 @@ // Copyright © 2024 Apple Inc. 
-template <typename T, typename Op> +template <typename T, typename Op, int N = WorkPerThread<T>::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } } -template <typename T, typename Op> +template <typename T, typename Op, int N = WorkPerThread<T>::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } } template diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index cceb53061..6da258b6f 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -8,8 +8,8 @@ #include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" -#define instantiate_ternary_all(op, tname, type) \ - instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ +#define instantiate_ternary_base(op, tname, type) \ + instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \ instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \ @@ -20,19 +20,23 @@ instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ +#define instantiate_ternary_all(op, tname, type) \ + instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \ + instantiate_ternary_base(op, tname, type) + #define instantiate_ternary_types(op) \ instantiate_ternary_all(op, bool_, bool) \ instantiate_ternary_all(op, uint8, uint8_t) \ instantiate_ternary_all(op, uint16, uint16_t) \ instantiate_ternary_all(op, uint32, uint32_t) \ - instantiate_ternary_all(op, uint64, uint64_t) \ + instantiate_ternary_base(op, uint64, uint64_t) \ instantiate_ternary_all(op, int8, int8_t) \ instantiate_ternary_all(op, int16, int16_t) \ instantiate_ternary_all(op, int32, int32_t) \ - instantiate_ternary_all(op, int64, int64_t) \ + instantiate_ternary_base(op, int64, int64_t) \ instantiate_ternary_all(op, float16, half) \ instantiate_ternary_all(op, float32, float) \ instantiate_ternary_all(op, bfloat16, bfloat16_t) \ - instantiate_ternary_all(op, complex64, complex64_t) // clang-format on + instantiate_ternary_base(op, complex64, complex64_t) // clang-format on instantiate_ternary_types(Select) diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 69828599f..649ba7f2c 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -1,21 +1,40 @@ // Copyright © 2024 Apple Inc. 
-template <typename T, typename U, typename Op> +template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } + } } -template <typename T, typename U, typename Op> +template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } } template < diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 2209b0665..160ef4af1 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -5,31 +5,41 @@ #include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ - instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ - instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ - instantiate_kernel( \ - "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ - instantiate_kernel( \ +#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) + +#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \ + instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ + instantiate_kernel( \ + "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ + instantiate_kernel( \ "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) +#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) + #define instantiate_unary_all_same(op, tname, type) \ instantiate_unary_all(op, tname, tname, type, type) +#define instantiate_unary_base_same(op, tname, type) \ + instantiate_unary_base(op, tname, tname, type, type) + #define instantiate_unary_float(op) \ instantiate_unary_all_same(op, float16, half) \ instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, bfloat16, bfloat16_t) -#define instantiate_unary_int(op) \ - instantiate_unary_all_same(op, uint8, uint8_t) \ - instantiate_unary_all_same(op, uint16, uint16_t) \ - instantiate_unary_all_same(op, uint32, uint32_t) \ - instantiate_unary_all_same(op, uint64, uint64_t) \ - instantiate_unary_all_same(op, int8, int8_t) \ - instantiate_unary_all_same(op, int16, int16_t) \ - instantiate_unary_all_same(op, int32, int32_t) \ - instantiate_unary_all_same(op, int64, int64_t) +#define instantiate_unary_int(op) \ + instantiate_unary_all_same(op, uint8, uint8_t) \ + instantiate_unary_all_same(op, uint16, uint16_t) \ + 
instantiate_unary_all_same(op, uint32, uint32_t) \ + instantiate_unary_base_same(op, uint64, uint64_t) \ + instantiate_unary_all_same(op, int8, int8_t) \ + instantiate_unary_all_same(op, int16, int16_t) \ + instantiate_unary_all_same(op, int32, int32_t) \ + instantiate_unary_base_same(op, int64, int64_t) #define instantiate_unary_types(op) \ instantiate_unary_all_same(op, bool_, bool) \ @@ -68,22 +78,29 @@ instantiate_unary_float(Tanh) instantiate_unary_float(Round) instantiate_unary_int(BitwiseInvert) -instantiate_unary_all_same(Abs, complex64, complex64_t) -instantiate_unary_all_same(Conjugate, complex64, complex64_t) -instantiate_unary_all_same(Cos, complex64, complex64_t) -instantiate_unary_all_same(Cosh, complex64, complex64_t) -instantiate_unary_all_same(Exp, complex64, complex64_t) -instantiate_unary_all_same(Log, complex64, complex64_t) -instantiate_unary_all_same(Log2, complex64, complex64_t) -instantiate_unary_all_same(Log10, complex64, complex64_t) -instantiate_unary_all_same(Negative, complex64, complex64_t) -instantiate_unary_all_same(Sign, complex64, complex64_t) -instantiate_unary_all_same(Sin, complex64, complex64_t) -instantiate_unary_all_same(Sinh, complex64, complex64_t) -instantiate_unary_all_same(Tan, complex64, complex64_t) -instantiate_unary_all_same(Tanh, complex64, complex64_t) -instantiate_unary_all_same(Round, complex64, complex64_t) -instantiate_unary_all(Real, complex64, float32, complex64_t, float) -instantiate_unary_all(Imag, complex64, float32, complex64_t, float) +instantiate_unary_base_same(Abs, complex64, complex64_t) +instantiate_unary_base_same(ArcCos, complex64, complex64_t) +instantiate_unary_base_same(ArcSin, complex64, complex64_t) +instantiate_unary_base_same(ArcTan, complex64, complex64_t) +instantiate_unary_base_same(Conjugate, complex64, complex64_t) +instantiate_unary_base_same(Cos, complex64, complex64_t) +instantiate_unary_base_same(Cosh, complex64, complex64_t) +instantiate_unary_base_same(Exp, complex64, complex64_t) +instantiate_unary_base_same(Log, complex64, complex64_t) +instantiate_unary_base_same(Log1p, complex64, complex64_t) +instantiate_unary_base_same(Log2, complex64, complex64_t) +instantiate_unary_base_same(Log10, complex64, complex64_t) +instantiate_unary_base_same(Negative, complex64, complex64_t) +instantiate_unary_base_same(Sign, complex64, complex64_t) +instantiate_unary_base_same(Sin, complex64, complex64_t) +instantiate_unary_base_same(Sinh, complex64, complex64_t) +instantiate_unary_base_same(Square, complex64, complex64_t) +instantiate_unary_base_same(Sqrt, complex64, complex64_t) +instantiate_unary_base_same(Rsqrt, complex64, complex64_t) +instantiate_unary_base_same(Tan, complex64, complex64_t) +instantiate_unary_base_same(Tanh, complex64, complex64_t) +instantiate_unary_base_same(Round, complex64, complex64_t) +instantiate_unary_base(Real, complex64, float32, complex64_t, float) +instantiate_unary_base(Imag, complex64, float32, complex64_t, float) instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index 52e126b40..b34bc44ba 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -5,6 +5,7 @@ #include #include +#include "mlx/backend/metal/kernels/cexpf.h" #include "mlx/backend/metal/kernels/erf.h" #include "mlx/backend/metal/kernels/expm1f.h" @@ -17,27 +18,21 @@ struct Abs { T operator()(T x) { return metal::abs(x); }; - template <> uint8_t operator()(uint8_t x) { 
return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; - template <> complex64_t operator()(complex64_t x) { return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; }; @@ -48,6 +43,8 @@ struct ArcCos { T operator()(T x) { return metal::precise::acos(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcCosh { @@ -62,6 +59,8 @@ struct ArcSin { T operator()(T x) { return metal::precise::asin(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcSinh { @@ -76,6 +75,8 @@ struct ArcTan { T operator()(T x) { return metal::precise::atan(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcTanh { @@ -97,39 +98,30 @@ struct Ceil { T operator()(T x) { return metal::ceil(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; @@ -141,7 +133,6 @@ struct Cos { return metal::precise::cos(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::cos(x.real) * metal::precise::cosh(x.imag), @@ -155,7 +146,6 @@ struct Cosh { return metal::precise::cosh(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::cosh(x.real) * metal::precise::cos(x.imag), @@ -188,10 +178,8 @@ struct Exp { T operator()(T x) { return metal::precise::exp(x); }; - template <> complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + return cexpf(x); } }; @@ -207,39 +195,30 @@ struct Floor { T operator()(T x) { return metal::floor(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; @@ -258,7 +237,6 @@ struct Log { return metal::precise::log(x); }; - template <> complex64_t operator()(complex64_t x) { auto r = metal::precise::log(Abs{}(x).real); auto i = metal::precise::atan2(x.imag, x.real); @@ -272,7 +250,6 @@ struct Log2 { return metal::precise::log2(x); }; - template <> complex64_t operator()(complex64_t x) { auto y = Log{}(x); return {y.real / M_LN2_F, y.imag / M_LN2_F}; @@ -285,7 +262,6 @@ struct Log10 { return metal::precise::log10(x); }; - template <> complex64_t operator()(complex64_t x) { auto y = Log{}(x); return {y.real / M_LN10_F, y.imag / M_LN10_F}; @@ -325,7 +301,6 @@ struct Round { T operator()(T x) { return metal::rint(x); }; - template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; }; @@ -344,11 +319,9 @@ struct Sign { T operator()(T x) { 
return (x > T(0)) - (x < T(0)); }; - template <> uint32_t operator()(uint32_t x) { return x != 0; }; - template <> complex64_t operator()(complex64_t x) { if (x == complex64_t(0)) { return x; @@ -364,7 +337,6 @@ struct Sin { return metal::precise::sin(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::sin(x.real) * metal::precise::cosh(x.imag), @@ -378,7 +350,6 @@ struct Sinh { return metal::precise::sinh(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::sinh(x.real) * metal::precise::cos(x.imag), @@ -398,6 +369,17 @@ struct Sqrt { T operator()(T x) { return metal::precise::sqrt(x); }; + + complex64_t operator()(complex64_t x) { + if (x.real == 0.0 && x.imag == 0.0) { + return {0.0, 0.0}; + } + auto r = Abs{}(x).real; + auto a = metal::precise::sqrt((r + x.real) / 2.0); + auto b_abs = metal::precise::sqrt((r - x.real) / 2.0); + auto b = metal::copysign(b_abs, x.imag); + return {a, b}; + } }; struct Rsqrt { @@ -405,6 +387,10 @@ struct Rsqrt { T operator()(T x) { return metal::precise::rsqrt(x); }; + + complex64_t operator()(complex64_t x) { + return 1.0 / Sqrt{}(x); + } }; struct Tan { @@ -413,7 +399,6 @@ struct Tan { return metal::precise::tan(x); }; - template <> complex64_t operator()(complex64_t x) { float tan_a = metal::precise::tan(x.real); float tanh_b = metal::precise::tanh(x.imag); @@ -429,7 +414,6 @@ struct Tanh { return metal::precise::tanh(x); }; - template <> complex64_t operator()(complex64_t x) { float tanh_a = metal::precise::tanh(x.real); float tan_b = metal::precise::tan(x.imag); @@ -438,3 +422,21 @@ struct Tanh { return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; }; }; + +complex64_t ArcCos::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcSin::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(i * x + Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcTan::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto ix = i * x; + return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix)); +}; diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index b31cd20d6..c30d186b8 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -15,6 +15,14 @@ typedef half float16_t; +// Work per thread values for different types. 
The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// @@ -328,6 +336,23 @@ inline bfloat16_t log1p(bfloat16_t x) { return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); } +inline complex64_t log1p(complex64_t in) { + float x = in.real; + float y = in.imag; + float zabs = metal::precise::sqrt(x * x + y * y); + float theta = metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); + return {metal::log(z0), theta}; + } +} + /////////////////////////////////////////////////////////////////////////////// // SIMD shuffle ops /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp index 4901190e1..e53bc58d9 100644 --- a/mlx/backend/metal/logsumexp.cpp +++ b/mlx/backend/metal/logsumexp.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index f55d20c9f..55b8be3a9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -6,8 +6,8 @@ #include #include "mlx/backend/common/broadcasting.h" -#include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/common/matmul.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" @@ -21,69 +21,6 @@ namespace mlx::core { namespace { -inline auto collapse_batches(const array& a, const array& b) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - if (A_bshape != B_bshape) { - std::ostringstream msg; - msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << "."; - throw std::runtime_error(msg.str()); - } - - Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; - Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; - - auto [batch_shape, batch_strides] = - collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); - - auto A_batch_stride = batch_strides[0]; - auto B_batch_stride = batch_strides[1]; - - if (batch_shape.empty()) { - batch_shape.push_back(1); - A_batch_stride.push_back(0); - B_batch_stride.push_back(0); - } - - return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride); -} - -inline auto collapse_batches(const array& a, const array& b, const array& c) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - Shape C_bshape{c.shape().begin(), c.shape().end() - 2}; - if (A_bshape != B_bshape || A_bshape != C_bshape) { - std::ostringstream msg; - msg << "[addmm] Got 
matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << ", B " << c.shape() << "."; - throw std::runtime_error(msg.str()); - } - - Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; - Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; - Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; - - auto [batch_shape, batch_strides] = collapse_contiguous_dims( - A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); - - auto A_batch_stride = batch_strides[0]; - auto B_batch_stride = batch_strides[1]; - auto C_batch_stride = batch_strides[2]; - - if (batch_shape.empty()) { - batch_shape.push_back(1); - A_batch_stride.push_back(0); - B_batch_stride.push_back(0); - C_batch_stride.push_back(0); - } - - return std::make_tuple( - batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); -} - std::tuple check_transpose( std::vector& copies, const Stream& s, @@ -227,11 +164,17 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { wn = 2; \ } -void steel_matmul_regular( +/////////////////////////////////////////////////////////////////////////////// +// Regular steel matmul dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -242,12 +185,15 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, - std::vector& copies) { + int64_t C_batch_stride /* = 0*/, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { using namespace mlx::steel; // Determine dispatch kernel @@ -259,16 +205,21 @@ void steel_matmul_regular( // Prepare kernel name std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; + + // clang-format off + kname << "steel_gemm_fused_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = false; - const bool do_axpby = false; + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; @@ -295,18 +246,18 @@ void steel_matmul_regular( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ base_name, + /* const std::string& hash_name = */ hash_name, + /* const metal::MTLFCList& func_consts = */ func_consts, + /* const array& out = */ out, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); @@ -349,8 +300,25 @@ void steel_matmul_regular( compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } + + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; + + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -358,7 +326,437 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } -void steel_matmul( +/////////////////////////////////////////////////////////////////////////////// +// Split k steel matmul +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_gemm_splitk_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + float alpha = 1.0f, + float beta = 0.0f) { + using namespace mlx::steel; + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc(C_split.nbytes())); + copies.push_back(C_split); + + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_splitk_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(C_split) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn + << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" + << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on + + // Encode and dispatch gemm kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_splitk_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kname.str(), + /* const array& in = */ a, + /* const array& out = */ C_split, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn, + /* bool mn_aligned = */ mn_aligned, + /* bool k_aligned = */ k_aligned); + + compute_encoder.set_compute_pipeline_state(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldc = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int split_k_partitions = */ split_k_partitions, + /* const int split_k_partition_stride = */ split_k_partition_stride, + /* const int split_k_partition_size = */ split_k_partition_size, + /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(C_split, 2); + + compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Do accum kernel + { + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split); + + if (do_axpby) { + kernel_name = kernel_name + "_axbpy"; + } + + auto kernel = get_steel_gemm_splitk_accum_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kernel_name, + /* const array& in = */ C_split, + /* const array& out = */ out, + /* bool axbpy = */ do_axpby); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set the arguments for the kernel + compute_encoder.set_input_array(C_split, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); + + if (do_axpby) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + compute_encoder.set_input_array(c, 5); + compute_encoder.set_bytes(ldc, 6); + compute_encoder.set_bytes(fdc, 7); + compute_encoder.set_bytes(alpha, 8); + compute_encoder.set_bytes(beta, 9); + } + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + auto group_dims = get_block_dims(N, M, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } + + d.add_temporaries(std::move(copies), s.index); +} + +/////////////////////////////////////////////////////////////////////////////// +// Split matmul routing +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int 
K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape /* = {} */, + Strides A_batch_stride /* = {} */, + Strides B_batch_stride /* = {} */, + Strides C_batch_stride /* = {} */, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { + if (batch_shape.empty()) { + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + if constexpr (CHECK_AB) { + auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] = + collapse_batches(a, b, c); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + C_batch_stride = C_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + C_batch_stride = {0}; + batch_shape = {1}; + } + } else { + auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + batch_shape = {1}; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Split K specialization + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* float alpha = */ alpha, + /* float beta = */ beta); + } + + ///////////////////////////////////////////////////////////////////////////// + // Regular kernel dispatch + auto batch_strides = A_batch_stride; + batch_strides.insert( + batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + if (CHECK_AB && !C_batch_stride.empty()) { + batch_strides.insert( + batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); + } + + int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back(); + int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back(); + int64_t C_batch_stride_ = C_batch_stride.empty() ? 
0 : C_batch_stride.back(); + + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ N, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t A_batch_stride = */ A_batch_stride_, + /* int64_t B_batch_stride = */ B_batch_stride_, + /* int64_t matrix_stride_out = */ int64_t(M) * N, + /* int64_t C_batch_stride = */ C_batch_stride_, + /* float alpha = */ alpha, + /* float beta = */ beta); +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMV dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void gemv_axbpy( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? ldb : lda; + + auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; + auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; + + int stride_mat = batch_strides_mat.back(); + int stride_vec = batch_strides_vec.back(); + + // Determine if inputs have simple batching / broadcasting + bool contiguous_kernel = (batch_shape.size() == 1); + + int batch_ndim = batch_shape.size(); + + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; + } else { + sm = 8; + sn = 4; + } + + if (out_vector_len >= 2048) { + bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; + } + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_" << type_to_name(out); + + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + sn = 32; + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; + + n_out_per_tgp = bm * sm * tm; + kname << "gemv_" << type_to_name(out); + } + + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + // clang-format off + kname << "_bm" << bm << "_bn" << bn + << "_sm" << sm << "_sn" << sn + << "_tm" << tm << "_tn" << tn + << "_nc" << !contiguous_kernel + << "_axpby" << do_axpby; // clang-format on + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); + + if (do_axpby) { + compute_encoder.set_input_array(c, 2); + + compute_encoder.set_bytes(alpha, 7); + compute_encoder.set_bytes(beta, 8); + + compute_encoder.set_vector_bytes(C_batch_stride, 13); + + int bias_stride = c.strides()[c.ndim() - 1]; + compute_encoder.set_bytes(bias_stride, 14); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + d.add_temporaries(std::move(copies), s.index); +} + +inline void gemv( const Stream& s, metal::Device& d, const array& a, @@ -373,166 +771,34 @@ void steel_matmul( bool transpose_a, bool transpose_b, std::vector& copies, - Shape batch_shape /* = {} */, - Strides A_batch_stride /* = {} */, - Strides B_batch_stride /* = {} */) { - using namespace mlx::steel; - - if (batch_shape.empty()) { - ///////////////////////////////////////////////////////////////////////////// - // Check and collapse batch dimensions - auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); - - batch_shape = batch_shape_; - A_batch_stride = A_bstride_; - B_batch_stride = B_bstride_; - // Collapse batches into M if needed - if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && - a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && - B_batch_stride.back() == 0) { - M *= batch_shape.back(); - batch_size_out = 1; - - A_batch_stride = {0}; - B_batch_stride = {0}; - batch_shape = {1}; - } - } - - size_t matrix_stride_out = size_t(M) * N; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 
't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldc = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int split_k_partitions = */ split_k_partitions, - /* const int split_k_partition_stride = */ split_k_partition_stride, - /* const int split_k_partition_size = */ split_k_partition_size, - /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split); - - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, false); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; - } - - ///////////////////////////////////////////////////////////////////////////// - // Regular kernel dispatch - auto batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - - steel_matmul_regular( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - lda, - ldb, - N, - transpose_a, - transpose_b, - std::move(batch_shape), - std::move(batch_strides), - A_batch_stride.back(), - B_batch_stride.back(), - matrix_stride_out, - copies); + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}) { + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); } 
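// A minimal sketch, not part of this diff: how the dispatch heuristics above
// route a matmul call. The thresholds mirror the checks in steel_matmul_axpby
// and the gemv routing in Matmul::eval_gpu; the helper name route_gemm is
// hypothetical and for illustration only.
#include <algorithm>

inline const char* route_gemm(int M, int N, int K, int batch_size_out) {
  if (std::min(M, N) == 1) {
    return "gemv";  // matrix-vector product path
  }
  int tm = M / 16, tn = N / 16, tk = K / 16;  // output and K tile counts
  if (batch_size_out == 1 && (tm * tn) <= 32 && tk >= 8) {
    return "steel_gemm_splitk";  // few output tiles but deep K: split across K
  }
  return "steel_gemm_fused";  // regular fused Steel GEMM
}
// Design note: split-K trades an extra accumulation kernel for more
// parallelism when the output is small but the reduction dimension is large.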
+/////////////////////////////////////////////////////////////////////////////// +// Matmul implementation +/////////////////////////////////////////////////////////////////////////////// + void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { @@ -591,102 +857,26 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 
1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby0"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ a_cols, + /* int ldb = */ b_cols, + /* bool transpose_a = */ a_transposed, + /* bool transpose_b = */ b_transposed, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } + ///////////////////////////////////////////////////////////////////////////// // Gemm specialization @@ -704,18 +894,39 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, - /* std::vector& = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } +/////////////////////////////////////////////////////////////////////////////// +// AddMM implementation +/////////////////////////////////////////////////////////////////////////////// + void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } + + // Return 0s if either input is empty + if (out.size() == 0) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + + // Copy c into out and return + if (inputs[0].shape(-1) == 0) { + copy_gpu( + inputs[2], + out, + inputs[2].flags().row_contiguous ? 
CopyType::Vector : CopyType::General, + stream()); + return; + } + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); @@ -772,346 +983,61 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 
1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby1"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(alpha_, 7); - compute_encoder.set_bytes(beta_, 8); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - compute_encoder.set_vector_bytes(C_batch_stride, 13); - - int bias_stride = c.strides()[c.ndim() - 1]; - compute_encoder.set_bytes(bias_stride, 14); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; - } - - using namespace mlx::steel; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? 
"t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - split_k_partitions, - split_k_partition_stride, - split_k_partition_size, - gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split) + "_axbpy"; - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, true); - - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - compute_encoder.set_input_array(c, 5); - compute_encoder.set_bytes(ldc, 6); - compute_encoder.set_bytes(fdc, 7); - compute_encoder.set_bytes(alpha_, 8); - compute_encoder.set_bytes(beta_, 9); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch - // Determine dispatch kernel - int bm = 64, bn = 64, bk = 16; - int wm = 2, wn = 2; - - char devc = d.get_architecture().back(); - GEMM_TPARAM_MACRO(devc) - - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - - const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = true; - const bool do_axpby = !(alpha_ == 1. 
&& beta_ == 1.); - const bool align_M = (M % bm) == 0; - const bool align_N = (N % bn) == 0; - const bool align_K = (K % bk) == 0; - - metal::MTLFCList func_consts = { - {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, - {&align_M, MTL::DataType::DataTypeBool, 200}, - {&align_N, MTL::DataType::DataTypeBool, 201}, - {&align_K, MTL::DataType::DataTypeBool, 202}, - }; - - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams gemm_params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ A_batch_stride.back(), - /* const int64_t batch_stride_b = */ B_batch_stride.back(), - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, - /* const int gemm_k_iterations_aligned = */ (K / bk), - /* const int batch_ndim = */ int(batch_shape.size())}; - - GEMMAddMMParams params{ - /* const int ldc = */ ldc, - /* const int fdc = */ fdc, - /* const int64_t batch_stride_c = */ C_batch_stride.back(), - /* const float alpha = */ alpha_, - /* const float beta = */ beta_}; - - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - - Strides batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - batch_strides.insert( - batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - - // Launch kernel - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(gemm_params, 4); - compute_encoder.set_bytes(params, 5); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* 
Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } +/////////////////////////////////////////////////////////////////////////////// +// BlockMaskedMM implementation +/////////////////////////////////////////////////////////////////////////////// + void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); @@ -1500,6 +1426,10 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// GatherMM implementation +/////////////////////////////////////////////////////////////////////////////// + void gather_mm_rhs( const array& a_, const array& b_, @@ -1934,4 +1864,166 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } +void segmented_mm( + const array& a_, + const array& b_, + const array& segments_, + array& out, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + auto check_segments_layout = [&d, &s](const array& x) { + // Contiguous so return early + if (x.flags().row_contiguous) { + return std::make_tuple(true, x); + } + + bool rc = true; + for (int i = 0; i < x.ndim() - 2; i++) { + rc &= + (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1); + } + rc &= x.strides(x.ndim() - 1) == 1; + if (x.ndim() > 1) { + rc &= x.strides(x.ndim() - 2) == 1; + } + + if (rc) { + return std::make_tuple(false, x); + } + + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return std::make_tuple(true, x_copy); + }; + + // Copy if needed + std::vector<array> copies; + auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); + auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); + auto [segments_contiguous, segments] = check_segments_layout(segments_); + d.add_temporaries(std::move(copies), s.index); + + // Determine dispatch kernel + int bm = 64, bn = 64, bk = 16; + int wm = 2, wn = 2; + size_t batch_size_out = out.size() / M / N; + + char devc = d.get_architecture().back(); + GEMM_TPARAM_MACRO(devc) + + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + + // Define the kernel name + std::string base_name; + base_name.reserve(128); + concatenate( + base_name, + "steel_segmented_mm_", + transpose_a ? 't' : 'n', + transpose_b ? 't' : 'n', + "_", + type_to_name(a), + "_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + + metal::MTLFCList func_consts = { + {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_segments_contiguous_", + segments_contiguous ? 't' : 'n', + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ?
't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_segmented_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); + compute_encoder.set_compute_pipeline_state(kernel); + + // Prepare the matmul params + steel::GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ static_cast(lda), + /* const int ldb = */ static_cast(ldb), + /* const int ldd = */ N, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ 0, + /* const int64_t batch_stride_b = */ 0, + /* const int64_t batch_stride_d = */ M * N, + /* const int swizzle_log = */ 0, + /* const int gemm_k_iterations_aligned = */ 0, + /* const int batch_ndim = */ 0}; + + // Prepare the grid + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = + MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(segments, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(params, 4); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void SegmentedMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& segments = inputs[2]; + + out.set_data(allocator::malloc(out.nbytes())); + + // Extract shapes from inputs. + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + segmented_mm(a, b, segments, out, M, N, K, d, s); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 09ffe05a8..218664b1f 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -6,7 +6,34 @@ namespace mlx::core { -void steel_matmul_regular( +template +void steel_matmul_regular_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride = 0, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul_regular( const Stream& s, metal::Device& d, const array& a, @@ -21,14 +48,61 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies); + int64_t matrix_stride_out) { + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, 
+ /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out); +} -void steel_matmul( +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul( const Stream& s, metal::Device& d, const array& a, @@ -45,6 +119,26 @@ void steel_matmul( std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, - Strides B_batch_stride = {}); + Strides B_batch_stride = {}) { + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} } // namespace mlx::core diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index a9a1bc4f6..888207322 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,11 +1,11 @@ // Copyright © 2023-2024 Apple Inc. #include +#include + #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" -#include "mlx/scheduler.h" -#include "mlx/utils.h" namespace mlx::core::metal { @@ -13,85 +13,6 @@ bool is_available() { return true; } -inline void check_error(MTL::CommandBuffer* cbuf) { - if (cbuf->status() == MTL::CommandBufferStatusError) { - std::ostringstream msg; - msg << "[METAL] Command buffer execution failed: " - << cbuf->error()->localizedDescription()->utf8String(); - throw std::runtime_error(msg.str()); - } -} - -void eval(array& arr) { - auto pool = new_scoped_memory_pool(); - auto s = arr.primitive().stream(); - auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); - - auto outputs = arr.outputs(); - { - // If the array is a tracer hold a reference - // to its inputs so they don't get donated - std::vector inputs; - if (arr.is_tracer()) { - inputs = arr.inputs(); - } - - debug_set_primitive_buffer_label(command_buffer, arr.primitive()); - arr.primitive().eval_gpu(arr.inputs(), outputs); - } - std::unordered_set> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - - if (d.command_buffer_needs_commit(s.index)) { - d.end_encoding(s.index); - scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); - } else { - command_buffer->addCompletedHandler( - [s, 
buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); - }); - } -} - -void finalize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - d.end_encoding(s.index); - cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); -} - -void synchronize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - cb->retain(); - d.end_encoding(s.index); - d.commit_command_buffer(s.index); - cb->waitUntilCompleted(); - check_error(cb); - cb->release(); -} - void start_capture(std::string path, id object) { auto pool = new_scoped_memory_pool(); @@ -128,4 +49,36 @@ void stop_capture() { manager->stopCapture(); } +const std::unordered_map>& +device_info() { + auto init_device_info = []() + -> std::unordered_map> { + auto pool = new_scoped_memory_pool(); + auto raw_device = device(default_device()).mtl_device(); + auto name = std::string(raw_device->name()->utf8String()); + auto arch = std::string(raw_device->architecture()->name()->utf8String()); + + size_t memsize = 0; + size_t length = sizeof(memsize); + sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + + size_t rsrc_limit = 0; + sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); + if (rsrc_limit == 0) { + rsrc_limit = 499000; + } + + return { + {"device_name", name}, + {"architecture", arch}, + {"max_buffer_length", raw_device->maxBufferLength()}, + {"max_recommended_working_set_size", + raw_device->recommendedMaxWorkingSetSize()}, + {"memory_size", memsize}, + {"resource_limit", rsrc_limit}}; + }; + static auto device_info_ = init_device_info(); + return device_info_; +} + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index d162007d1..af2995b63 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -2,11 +2,10 @@ #pragma once +#include #include #include -#include "mlx/array.h" - namespace mlx::core::metal { /* Check if the Metal backend is available. */ diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp new file mode 100644 index 000000000..9785e07c2 --- /dev/null +++ b/mlx/backend/metal/no_metal.cpp @@ -0,0 +1,42 @@ +// Copyright © 2025 Apple Inc. 
+ +#include + +#include "mlx/backend/metal/metal.h" +#include "mlx/fast.h" + +namespace mlx::core { + +namespace metal { + +bool is_available() { + return false; +} + +void start_capture(std::string) {} +void stop_capture() {} + +const std::unordered_map>& +device_info() { + throw std::runtime_error( + "[metal::device_info] Cannot get device info without metal backend"); +}; + +} // namespace metal + +namespace fast { + +MetalKernelFunction metal_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool ensure_row_contiguous, + bool atomic_outputs) { + throw std::runtime_error("[metal_kernel] No GPU back-end."); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 8da147971..a689a793e 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel( int, int, int) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( @@ -207,7 +207,23 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( int, int, bool) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); +} + +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int) { + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_gemv_masked_kernel( @@ -244,13 +260,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel( MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, const array&, int, int, int, int, int) { - return d.get_kernel(kernel_name); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_fft_kernel( @@ -259,7 +277,7 @@ MTL::ComputePipelineState* get_fft_kernel( const std::string& hash_name, const metal::MTLFCList& func_consts, const std::string&) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_quantized_kernel( @@ -283,7 +301,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( int, int, bool) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, 
func_consts); } } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c1d993d2a..8674eff72 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/reduce.h" @@ -10,6 +10,10 @@ namespace mlx::core::fast { +bool RMSNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + void RMSNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -22,7 +26,7 @@ void RMSNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { @@ -168,7 +172,7 @@ void RMSNormVJP::eval_gpu( auto& compute_encoder = d.get_command_encoder(s.index); { - auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { @@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu( } } +bool LayerNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + void LayerNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -219,7 +227,7 @@ void LayerNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { @@ -247,12 +255,13 @@ void LayerNorm::eval_gpu( auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; - const int simd_size = 32; - const int n_reads = RMS_N_READS; - const int looped_limit = RMS_LOOPED_LIMIT; + int simd_size = 32; + int n_reads = 8; + int looped_limit = 6656; std::string op_name = "layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; + n_reads = 4; } op_name += type_to_name(out); auto& compute_encoder = d.get_command_encoder(s.index); @@ -264,7 +273,13 @@ void LayerNorm::eval_gpu( size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); @@ -364,12 +379,13 @@ void LayerNormVJP::eval_gpu( g, gb, "sum", plan, {0}, compute_encoder, d, s); } - const int simd_size = 32; - const int n_reads = RMS_N_READS; - const int looped_limit = RMS_LOOPED_LIMIT; + int simd_size = 32; + int n_reads = 8; + int looped_limit = 8192; std::string op_name = "vjp_layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; + n_reads = 4; } op_name += 
type_to_name(gx); @@ -379,14 +395,20 @@ void LayerNormVJP::eval_gpu( }; { - auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6946ffb9e..2ac543ad8 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -7,10 +7,10 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(step, 1); } -void reshape(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - static array compute_dynamic_offset( const array& indices, const Strides& strides, @@ -201,8 +182,8 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { (thread_group_size + simd_size - 1) / simd_size * simd_size; assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); - size_t n_threads = out.size() * thread_group_size; - MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + auto gd = get_2d_grid_dims(out.shape(), out.strides()); + MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); @@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } } -void AsType::eval_gpu(const std::vector& inputs, array& out) { - CopyType ctype = - inputs[0].flags().contiguous ? 
CopyType::Vector : CopyType::General; - copy_gpu(inputs[0], out, ctype); -} - -void AsStrided::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Broadcast::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Concatenate::eval_gpu(const std::vector& inputs, array& out) { - concatenate_gpu(inputs, out, axis_, stream()); -} - -void Contiguous::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - constexpr size_t extra_bytes = 16384; - if (in.buffer_size() <= out.nbytes() + extra_bytes && - (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous))) { - out.copy_shared_buffer(in); - } else { - copy_gpu(in, out, CopyType::General); - } -} - -void Copy::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void CustomTransforms::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Depends::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Full::eval_gpu(const std::vector& inputs, array& out) { - auto in = inputs[0]; - CopyType ctype; - if (in.data_size() == 1) { - ctype = CopyType::Scalar; - } else if (in.flags().contiguous) { - ctype = CopyType::Vector; - } else { - ctype = CopyType::General; - } - copy_gpu(in, out, ctype); -} - -void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Flatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Unflatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - void Load::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error("[Load::eval_gpu] Not implemented."); } -void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Pad::eval_gpu(const std::vector& inputs, array& out) { - // Inputs must be base input array and scalar val array - assert(inputs.size() == 2); - auto& in = inputs[0]; - auto& val = inputs[1]; - - // Padding value must be a scalar - assert(val.size() == 1); - - // Padding value, input and output must be of the same type - assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); - - pad_gpu(in, val, out, axes_, low_pad_size_, stream()); -} - void RandomBits::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); @@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void Reshape::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Split::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Slice::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - slice_gpu(in, out, start_indices_, strides_, stream()); -} - void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { out.set_data(nullptr); @@ -457,53 +322,6 @@ void DynamicSliceUpdate::eval_gpu( /* const std::optional& dynamic_o_offset = */ out_offset); } -void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - if (out.size() == 0) { - 
out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - auto& upd = inputs[1]; - - if (upd.size() == 0) { - out.copy_shared_buffer(in); - return; - } - - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); - auto [data_offset, out_strides] = - prepare_slice(out, start_indices_, strides_); - - // Do copy - copy_gpu_inplace( - /* const array& src = */ upd, - /* array& dst = */ out, - /* const Shape& data_shape = */ upd.shape(), - /* const Strides& i_strides = */ upd.strides(), - /* const Strides& o_strides = */ out_strides, - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ data_offset, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ stream()); -} - -void Squeeze::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void StopGradient::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Transpose::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -525,10 +343,16 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } +void Eig::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI."); +} + void Eigh::eval_gpu( const std::vector& inputs, std::vector& outputs) { - throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI."); + throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); } void LUF::eval_gpu( @@ -537,35 +361,4 @@ void LUF::eval_gpu( throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); } -void View::eval_gpu(const std::vector& inputs, array& out) { - auto& in = inputs[0]; - auto ibytes = size_of(in.dtype()); - auto obytes = size_of(out.dtype()); - // Conditions for buffer copying (disjunction): - // - type size is the same - // - type size is smaller and the last axis is contiguous - // - the entire array is row contiguous - if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || - in.flags().row_contiguous) { - auto strides = in.strides(); - for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { - strides[i] *= ibytes; - strides[i] /= obytes; - } - out.copy_shared_buffer( - in, strides, in.flags(), in.data_size() * ibytes / obytes); - } else { - auto tmp = array(in.shape(), in.dtype(), nullptr, {}); - tmp.set_data(allocator::malloc(tmp.nbytes())); - copy_gpu_inplace(in, tmp, CopyType::General, stream()); - - auto flags = out.flags(); - flags.contiguous = true; - flags.row_contiguous = true; - auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); - flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; - out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); - } -} - } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 6f5807543..b6dc8db30 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -4,7 +4,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/reduce.h" @@ -976,7 +976,9 @@ void fast::AffineQuantize::eval_gpu( // Treat 
uint32 as uint8 in kernel constexpr int uint8_per_uint32 = 4; constexpr int simd_size = 32; - int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_; + int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8 + : bits_ == 6 ? 4 + : 8 / bits_; int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size; size_t nthreads = dequantize_ ? out.size() / packs_per_int : w.size() / per_thread; diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index c5650bdd7..8cb55ba58 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 0a9e1b861..798824c2f 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/resident.h" -#include "mlx/backend/metal/metal_impl.h" namespace mlx::core::metal { diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 060758333..e141df630 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -1,5 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" @@ -7,6 +7,10 @@ namespace mlx::core::fast { constexpr int n_per_thread = 4; +bool RoPE::use_fallback(Stream s) { + return s.device == Device::cpu; +} + void RoPE::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 845962d01..eef279d1d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -2,12 +2,12 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" - #include "mlx/backend/metal/kernels/steel/attn/params.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" +#include "mlx/transforms_impl.h" #include "mlx/utils.h" namespace mlx::core::fast { @@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal( std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); const int NQ = (qL + bq - 1) / bq; @@ -154,9 +154,9 @@ void sdpa_vector( int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? 
v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(1024, 1, 1); @@ -180,7 +180,7 @@ void sdpa_vector( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -199,11 +199,10 @@ void sdpa_vector( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 11 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 13); compute_encoder.set_bytes(q_seq_stride, 14); compute_encoder.set_bytes(head_stride, 15); @@ -238,9 +237,10 @@ void sdpa_vector_2pass( int N = k.shape(2); int blocks = 32; int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(8 * 32, 1, 1); MTL::Size grid_dims(B, q.shape(2), blocks); @@ -281,7 +281,7 @@ void sdpa_vector_2pass( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); @@ -302,11 +302,10 @@ void sdpa_vector_2pass( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 13 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? 
m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 15); compute_encoder.set_bytes(q_seq_stride, 16); compute_encoder.set_bytes(head_stride, 17); @@ -340,6 +339,46 @@ void sdpa_vector_2pass( } // namespace +bool ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + if (detail::in_grad_tracing()) { + return true; + } + if (s.device == Device::cpu) { + return true; + } + + const int value_head_dim = v.shape(-1); + const int query_head_dim = q.shape(-1); + const int query_sequence_length = q.shape(2); + const int key_sequence_length = k.shape(2); + + const bool sdpa_vector_supported_head_dim = + query_head_dim == value_head_dim && + (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || + query_head_dim == 256); + const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && + (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); + + const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || + (query_sequence_length <= key_sequence_length && do_causal); + + const bool supports_sdpa_full = query_sequence_length > 8 && + sdpa_full_supported_mask && sdpa_full_supported_head_dim; + + const bool supports_sdpa_vector = (query_sequence_length <= 8) && + (query_sequence_length <= key_sequence_length) && + sdpa_vector_supported_head_dim; + + return !(supports_sdpa_full || supports_sdpa_vector); +} + void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, array& out) { @@ -368,18 +407,6 @@ void ScaledDotProductAttention::eval_gpu( } }; - // Checks if arr is row contiguous or the sequence and head dimension are - // transposed - auto is_contiguous_or_head_seq_transposed = [](const array& arr) { - if (arr.flags().row_contiguous) { - return true; - } - auto& strides = arr.strides(); - auto& shape = arr.shape(); - return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) && - (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]); - }; - // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { return arr.strides(-1) == 1; @@ -387,30 +414,58 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) <= 8) { - const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre); - const auto& k = copy_unless(is_matrix_contiguous, k_pre); - const auto& v = copy_unless(is_matrix_contiguous, v_pre); + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + // If either the batch or head dimension is a singleton, the other can + // be transposed with the sequence dimension + auto bidx = shape[0] == 1 ? 
1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + // keys and values should be copied if: + // - the last dimension is not contiguous + // - the batch and head dim are not contiguous + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + const auto& k = copy_unless(kv_copy_unless, k_pre); + const auto& v = copy_unless(kv_copy_unless, v_pre); // Donate the query if possible - if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && - q.size() == o.size()) { + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { o.copy_shared_buffer(q); } else { - if (o.shape(2) == 1) { - o.set_data(allocator::malloc(o.nbytes())); - } else { - auto strides = o.strides(); - strides[2] = o.shape(1) * o.shape(3); - strides[1] = o.shape(3); - auto flags = q.flags(); - flags.row_contiguous = q.shape(1) == 1; - o.set_data( - allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags); - } + o.set_data(allocator::malloc(o.nbytes())); } - auto mask = - inputs.size() > 3 ? std::optional{inputs[3]} : std::nullopt; + auto mask_copy_unless = [&q](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); + }; + + auto mask = inputs.size() > 3 + ? std::optional{copy_unless(mask_copy_unless, inputs[3])} + : std::nullopt; // We route to the 2 pass fused attention if // - The device is large and the sequence length long diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index b1800fea9..3c4051105 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 6ab08a108..3e1a8b541 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -2,21 +2,12 @@ #include -#include "mlx/backend/common/slicing.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" namespace mlx::core { -void slice_gpu( - const array& in, - array& out, - const Shape& start_indices, - const Shape& strides, - const Stream& s) { - slice(in, out, start_indices, strides); -} - void concatenate_gpu( const std::vector& inputs, array& out, @@ -48,30 +39,4 @@ void concatenate_gpu( } } -void pad_gpu( - const array& in, - const array& val, - array& out, - const std::vector& axes, - const Shape& low_pad_size, - const Stream& s) { - // Fill output with val - fill_gpu(val, out, s); - - // Find offset for start of input values - size_t data_offset = 0; - for (int i = 0; i < axes.size(); i++) { - auto ax = axes[i] < 0 ? 
out.ndim() + axes[i] : axes[i]; - data_offset += out.strides()[ax] * low_pad_size[i]; - } - - // Extract slice from output where input will be pasted - array out_slice(in.shape(), out.dtype(), nullptr, {}); - out_slice.copy_shared_buffer( - out, out.strides(), out.flags(), out_slice.size(), data_offset); - - // Copy input values into the slice - copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); -} - } // namespace mlx::core diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 224721a50..59662b05d 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 543dfd180..3c84022f2 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -2,7 +2,7 @@ #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 36bfd3e2b..b2b9e3337 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -11,7 +11,7 @@ namespace mlx::core { void ternary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { assert(inputs.size() == 3); auto& a = inputs[0]; @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(b.dtype(), out.data_size()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -60,6 +60,8 @@ void ternary_op_gpu_inplace( } } else if (large) { kernel_name = "v2"; + } else if (work_per_thread > 1) { + kernel_name = "vn"; } else { kernel_name = "v"; } @@ -106,13 +108,19 @@ void ternary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? 
get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -120,7 +128,7 @@ void ternary_op_gpu_inplace( void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -133,13 +141,13 @@ void ternary_op_gpu( void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op) { + const char* op) { auto& s = out.primitive().stream(); ternary_op_gpu(inputs, out, op, s); } void Select::eval_gpu(const std::vector& inputs, array& out) { - ternary_op_gpu(inputs, out, get_primitive_string(this)); + ternary_op_gpu(inputs, out, name()); } } // namespace mlx::core diff --git a/mlx/backend/metal/ternary.h b/mlx/backend/metal/ternary.h index 0834140b8..91c6fbbeb 100644 --- a/mlx/backend/metal/ternary.h +++ b/mlx/backend/metal/ternary.h @@ -9,13 +9,13 @@ namespace mlx::core { void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); void ternary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index be43c41c2..48f85635b 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. -#include "mlx/backend/common/utils.h" + +#include "mlx/backend/common/unary.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" @@ -7,7 +8,7 @@ #define UNARY_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ - unary_op_gpu(inputs, out, get_primitive_string(this)); \ + unary_op_gpu(inputs, out, name()); \ } namespace mlx::core { @@ -15,7 +16,7 @@ namespace mlx::core { void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { auto& in = inputs[0]; bool contig = in.flags().contiguous; @@ -34,18 +35,19 @@ void unary_op_gpu_inplace( }; auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); - size_t nthreads = contig ? in.data_size() : in.size(); bool large; if (!contig) { large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; } else { large = in.data_size() > UINT32_MAX; } - int work_per_thread = !contig && large ? 4 : 1; + int work_per_thread; std::string kernel_name; if (contig) { - kernel_name = (large ? "v2" : "v"); + work_per_thread = get_work_per_thread(in.dtype(), in.data_size()); + kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v")); } else { + work_per_thread = large ? 4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); if (large) { kernel_name += "large"; @@ -75,12 +77,20 @@ void unary_op_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { + size_t nthreads = ceildiv(in.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? 
get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -88,30 +98,16 @@ void unary_op_gpu_inplace( void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { - auto& in = inputs[0]; - bool contig = in.flags().contiguous; - if (contig) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.copy_shared_buffer(in); - } else { - out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), - in.data_size(), - in.strides(), - in.flags()); - } - } else { - out.set_data(allocator::malloc(out.nbytes())); - } + set_unary_output_data(inputs[0], out); unary_op_gpu_inplace(inputs, out, op, s); } void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op) { + const char* op) { auto& s = out.primitive().stream(); unary_op_gpu(inputs, out, op, s); } @@ -150,13 +146,13 @@ UNARY_GPU(Tanh) void Log::eval_gpu(const std::vector& inputs, array& out) { switch (base_) { case Base::e: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; case Base::two: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; case Base::ten: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; } } @@ -165,7 +161,7 @@ void Round::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (issubdtype(in.dtype(), inexact)) { - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); } else { // No-op integer types out.copy_shared_buffer(in); diff --git a/mlx/backend/metal/unary.h b/mlx/backend/metal/unary.h index 19057076b..1d6ecf027 100644 --- a/mlx/backend/metal/unary.h +++ b/mlx/backend/metal/unary.h @@ -9,13 +9,13 @@ namespace mlx::core { void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp index 329d250df..978501835 100644 --- a/mlx/backend/metal/utils.cpp +++ b/mlx/backend/metal/utils.cpp @@ -1,8 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/metal/utils.h" - -using namespace mlx; +#include "mlx/backend/common/utils.h" namespace mlx::core { @@ -59,109 +58,20 @@ std::string type_to_name(const array& a) { return type_to_name(a.dtype()); } -MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { - int pows[3] = {0, 0, 0}; - int sum = 0; - while (true) { - int presum = sum; - // Check all the pows - if (dim0 >= (1 << (pows[0] + 1))) { - pows[0]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim1 >= (1 << (pows[1] + 1))) { - pows[1]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim2 >= (1 << (pows[2] + 1))) { - pows[2]++; - sum++; - } - if (sum == presum || sum == pow2) { - break; - } - } - return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; +MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) { + Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) { - // Dims with strides of 0 are ignored as they - // correspond to broadcasted dimensions - size_t grid_x = 1; - size_t grid_y = 1; - for (int i = 0; i < shape.size(); ++i) { - if (strides[i] == 0) { - continue; - } - if (grid_x * shape[i] < UINT32_MAX) { - grid_x *= shape[i]; - } else { - grid_y *= shape[i]; - } - } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { - throw std::runtime_error("Unable to safely factor shape."); - } - if (grid_y > grid_x) { - std::swap(grid_x, grid_y); - } - return MTL::Size( - static_cast(grid_x), static_cast(grid_y), 1); + Dims dims = get_2d_grid_dims_common(shape, strides); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) { - // Compute the 2d grid dimensions such that the total size of the grid is - // divided by divisor. - size_t grid_x = 1; - size_t grid_y = 1; - for (int i = 0; i < shape.size(); ++i) { - if (strides[i] == 0) { - continue; - } - - // No need to add this shape we can just remove it from the divisor. - if (divisor % shape[i] == 0) { - divisor /= shape[i]; - continue; - } - - if (grid_x * shape[i] < UINT32_MAX) { - grid_x *= shape[i]; - } else { - grid_y *= shape[i]; - } - - if (divisor > 1) { - if (grid_x % divisor == 0) { - grid_x /= divisor; - divisor = 1; - } else if (grid_y % divisor == 0) { - grid_y /= divisor; - divisor = 1; - } - } - } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { - throw std::runtime_error("Unable to safely factor shape."); - } - if (grid_y > grid_x) { - std::swap(grid_x, grid_y); - } - return MTL::Size( - static_cast(grid_x), static_cast(grid_y), 1); -} - -std::string get_primitive_string(Primitive* primitive) { - std::ostringstream op_t; - primitive->print(op_t); - return op_t.str(); + Dims dims = get_2d_grid_dims_common(shape, strides, divisor); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } } // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 079d15f17..e7784e599 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -13,22 +13,9 @@ namespace mlx::core { std::string type_to_name(const Dtype& t); std::string type_to_name(const array& a); -// Compute the thread block dimensions which fit the given -// input dimensions. 
-// - The thread block dimensions will be powers of two -// - The thread block size will be less than 2^pow2 +// Compute the grid and block dimensions, check backend/common/utils.h for docs. MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); - -// Computes a 2D grid where each element is < UINT_MAX -// Assumes: -// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 -// - shape and strides correspond to a contiguous (no holes) but -// possibly broadcasted array MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides); - -// Same as above but we do an implicit division with divisor. -// Basically, equivalent to factorizing -// Prod(s \forall s in shape if strides[s] > 0) / divisor. MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor); @@ -53,13 +40,11 @@ inline void debug_set_primitive_buffer_label( if (auto cbuf_label = command_buffer->label(); cbuf_label) { label << cbuf_label->utf8String(); } - primitive.print(label); + label << primitive.name(); command_buffer->setLabel(make_string(label)); #endif } -std::string get_primitive_string(Primitive* primitive); - template constexpr bool is_numeric_except_char = std::is_arithmetic_v && !std::is_same_v && !std::is_same_v && @@ -84,4 +69,16 @@ void concatenate(std::string& acc, T first, Args... args) { concatenate(acc, args...); } +inline int get_work_per_thread(Dtype dtype) { + return std::max(1, 8 / dtype.size()); +} +inline int get_work_per_thread(Dtype dtype, size_t size) { + constexpr size_t wpt_threshold = 1 << 16; + return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size()); +} + +inline size_t ceildiv(size_t n, size_t m) { + return (n + m - 1) / m; +} + } // namespace mlx::core diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index e1524ec63..2e6960829 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -1,6 +1,7 @@ target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) diff --git a/mlx/backend/no_cpu/available.cpp b/mlx/backend/no_cpu/available.cpp new file mode 100644 index 000000000..04c1bac8e --- /dev/null +++ b/mlx/backend/no_cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. 
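The two helpers added to metal/utils.h above size the launch: below 2^16 elements each thread handles one element, otherwise roughly 8 bytes of output per thread, and `ceildiv` shrinks the launched thread count by that factor before the threadgroup is clamped. A small numeric illustration of the resulting launch arithmetic (the concrete numbers are made up):

```cpp
#include <cstddef>
#include <iostream>

size_t ceildiv(size_t n, size_t m) {
  return (n + m - 1) / m;
}

int main() {
  size_t data_size = 1 << 20;               // 1M float32 elements
  int work_per_thread = 8 / sizeof(float);  // 2 elements per thread
  size_t nthreads = ceildiv(data_size, work_per_thread);
  size_t thread_group_size = 1024;          // kernel's max threads per group
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }
  std::cout << nthreads << " threads in groups of " << thread_group_size << '\n';
  // 524288 threads in groups of 1024
}
```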
+ +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return false; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/no_cpu/compiled.cpp b/mlx/backend/no_cpu/compiled.cpp index c1c42c735..2eeddab47 100644 --- a/mlx/backend/no_cpu/compiled.cpp +++ b/mlx/backend/no_cpu/compiled.cpp @@ -18,7 +18,7 @@ void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error( - "[Compiled::eval_cpu] CPU compialtion not supported on the platform."); + "[Compiled::eval_cpu] CPU compilation not supported on the platform."); } } // namespace mlx::core diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 84372b096..09e6c4ef3 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -55,6 +55,7 @@ NO_CPU(DynamicSlice) NO_CPU(DynamicSliceUpdate) NO_CPU(NumberOfElements) NO_CPU(Remainder) +NO_CPU_MULTI(Eig) NO_CPU_MULTI(Eigh) NO_CPU(Equal) NO_CPU(Erf) @@ -104,6 +105,7 @@ NO_CPU(Scan) NO_CPU(Scatter) NO_CPU(ScatterAxis) NO_CPU(Select) +NO_CPU(SegmentedMM) NO_CPU(Sigmoid) NO_CPU(Sign) NO_CPU(Sin) diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_gpu/CMakeLists.txt similarity index 82% rename from mlx/backend/no_metal/CMakeLists.txt rename to mlx/backend/no_gpu/CMakeLists.txt index 962ceecb7..78e15ac69 100644 --- a/mlx/backend/no_metal/CMakeLists.txt +++ b/mlx/backend/no_gpu/CMakeLists.txt @@ -3,5 +3,5 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp similarity index 96% rename from mlx/backend/no_metal/allocator.cpp rename to mlx/backend/no_gpu/allocator.cpp index a8b260b6b..320d1a267 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -6,9 +6,9 @@ #include "mlx/allocator.h" #ifdef __APPLE__ -#include "mlx/backend/no_metal/apple_memory.h" +#include "mlx/backend/no_gpu/apple_memory.h" #elif defined(__linux__) -#include "mlx/backend/no_metal/linux_memory.h" +#include "mlx/backend/no_gpu/linux_memory.h" #else size_t get_memory_size() { return 0; diff --git a/mlx/backend/no_metal/apple_memory.h b/mlx/backend/no_gpu/apple_memory.h similarity index 100% rename from mlx/backend/no_metal/apple_memory.h rename to mlx/backend/no_gpu/apple_memory.h diff --git a/mlx/backend/no_gpu/eval.cpp b/mlx/backend/no_gpu/eval.cpp new file mode 100644 index 000000000..8bff86a98 --- /dev/null +++ b/mlx/backend/no_gpu/eval.cpp @@ -0,0 +1,28 @@ +// Copyright © 2025 Apple Inc. 
+ +#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" + +namespace mlx::core::gpu { + +bool is_available() { + return false; +} + +void new_stream(Stream) {} + +void eval(array&) { + throw std::runtime_error("[gpu::eval] GPU backend is not available"); +} + +void finalize(Stream) { + throw std::runtime_error("[gpu::finalize] GPU backend is not available"); +} + +void synchronize(Stream) { + throw std::runtime_error("[gpu::synchronize] GPU backend is not available"); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_gpu/event.cpp similarity index 100% rename from mlx/backend/no_metal/event.cpp rename to mlx/backend/no_gpu/event.cpp diff --git a/mlx/backend/no_metal/fence.cpp b/mlx/backend/no_gpu/fence.cpp similarity index 100% rename from mlx/backend/no_metal/fence.cpp rename to mlx/backend/no_gpu/fence.cpp diff --git a/mlx/backend/no_metal/linux_memory.h b/mlx/backend/no_gpu/linux_memory.h similarity index 100% rename from mlx/backend/no_metal/linux_memory.h rename to mlx/backend/no_gpu/linux_memory.h diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp similarity index 84% rename from mlx/backend/no_metal/primitives.cpp rename to mlx/backend/no_gpu/primitives.cpp index 6826c97f6..dfe5b57f1 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -10,6 +10,12 @@ throw std::runtime_error(#func " has no GPU implementation."); \ } +#define NO_GPU_USE_FALLBACK(func) \ + bool func::use_fallback(Stream s) { \ + return true; \ + } \ + NO_GPU_MULTI(func) + #define NO_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ throw std::runtime_error(#func " has no GPU implementation."); \ @@ -17,6 +23,17 @@ namespace mlx::core { +bool fast::ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + return true; +} + NO_GPU(Abs) NO_GPU(Add) NO_GPU(AddMM) @@ -104,6 +121,7 @@ NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) NO_GPU(Select) +NO_GPU(SegmentedMM) NO_GPU(Sigmoid) NO_GPU(Sign) NO_GPU(Sin) @@ -126,14 +144,15 @@ NO_GPU(Unflatten) NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eigh) +NO_GPU_MULTI(Eig) NO_GPU(View) namespace fast { -NO_GPU_MULTI(LayerNorm) +NO_GPU_USE_FALLBACK(LayerNorm) NO_GPU_MULTI(LayerNormVJP) -NO_GPU_MULTI(RMSNorm) +NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) -NO_GPU_MULTI(RoPE) +NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp deleted file mode 100644 index ef9af8800..000000000 --- a/mlx/backend/no_metal/metal.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. 
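A condensed illustration of what the new `NO_GPU_USE_FALLBACK` macro above stamps out on the stub backend: a `use_fallback` that always answers yes plus the usual throwing `eval_gpu`. The struct, the simplified signatures, and the single-macro form below are illustrative only; the real macro defers to `NO_GPU_MULTI`.

```cpp
#include <stdexcept>

struct Stream {};

// Condensed version of the pattern: always fall back, and keep a throwing
// eval_gpu in case it is ever reached anyway.
#define NO_GPU_USE_FALLBACK(func)                                  \
  bool func::use_fallback(Stream) {                                \
    return true;                                                   \
  }                                                                \
  void func::eval_gpu() {                                          \
    throw std::runtime_error(#func " has no GPU implementation."); \
  }

// A primitive that has a fast GPU kernel on some backends.
struct RMSNormish {
  static bool use_fallback(Stream s);
  void eval_gpu();
};

// On the no-GPU build the macro stamps out both definitions:
NO_GPU_USE_FALLBACK(RMSNormish)

int main() {
  // use_fallback is consulted when the op is built, so the throwing eval_gpu
  // is never reached when the decomposed fallback graph is used instead.
  return RMSNormish::use_fallback(Stream{}) ? 0 : 1;
}
```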
- -#include - -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" -namespace mlx::core::metal { - -bool is_available() { - return false; -} - -void new_stream(Stream) {} - -std::unique_ptr> new_scoped_memory_pool() { - return nullptr; -} - -void eval(array&) { - throw std::runtime_error( - "[metal::eval] Cannot eval on GPU without metal backend"); -} - -void finalize(Stream) { - throw std::runtime_error( - "[metal::finalize] Cannot finalize GPU without metal backend"); -} - -void synchronize(Stream) { - throw std::runtime_error( - "[metal::synchronize] Cannot synchronize GPU without metal backend"); -} - -void start_capture(std::string) {} -void stop_capture() {} - -const std::unordered_map>& -device_info() { - throw std::runtime_error( - "[metal::device_info] Cannot get device info without metal backend"); -}; - -} // namespace mlx::core::metal diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 7ff5c8f9e..91743ec04 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,16 +1,20 @@ // Copyright © 2023-2024 Apple Inc. #include #include +#include #include #include #include "mlx/allocator.h" +#include "mlx/backend/common/compiled.h" #include "mlx/compile.h" #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" +#include "mlx/utils.h" namespace mlx::core { @@ -82,7 +86,54 @@ Compiled::Compiled( inputs_(std::move(inputs)), outputs_(std::move(outputs)), tape_(std::move(tape)), - constant_ids_(std::move(constant_ids)) {} + constant_ids_(std::move(constant_ids)), + is_constant_([this](size_t i) { + return constant_ids_.find(inputs_[i].id()) != constant_ids_.end(); + }) { + // Build the kernel name. + NodeNamer namer; + std::ostringstream os; + std::ostringstream constant_hasher; + + // Fill the input names. This is not really necessary, I just like having A, + // B, C, ... as the inputs. + for (const auto& x : inputs_) { + namer.get_name(x); + } + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. + for (const auto& a : tape_) { + // name and type of output + os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); + // computation performed + os << a.primitive().name(); + // name of inputs to the function + for (auto& inp : a.inputs()) { + os << namer.get_name(inp); + } + } + os << "_"; + + for (const auto& x : inputs_) { + if (constant_ids_.find(x.id()) != constant_ids_.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << (is_scalar(x) ? 
"S" : "V"); + } + } + os << "_"; + for (const auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + kernel_lib_ = os.str(); +} std::vector Compiled::vjp( const std::vector&, @@ -119,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const { }); } -void Compiled::print(std::ostream& os) { - os << "Compiled"; - for (auto& a : tape_) { - a.primitive().print(os); +const char* Compiled::name() const { + if (name_.empty()) { + std::ostringstream os; + os << "Compiled"; + for (auto& a : tape_) { + os << a.primitive().name(); + } + name_ = os.str(); } + return name_.c_str(); } std::vector Compiled::output_shapes(const std::vector& inputs) { @@ -168,6 +224,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) { parent.first.inputs()[parent.second] = dst; pairs.push_back(parent); } + + // If src is a parent of dst, remove it from dst's parents + for (auto it = pairs.begin(); it != pairs.end();) { + if (it->first.id() == src.id()) { + it = pairs.erase(it); + } else { + it++; + } + } // Remove the source from the map to avoid fusing with it again parents_map.erase(src_parents); } @@ -185,6 +250,30 @@ void merge(array& dst, array& src, ParentsMap& parents_map) { } } +// Any parent in the divider will continue to refer to `x` but any parent not +// in the divider will refer to a copy of the operation. +array split_one( + const array& x, + ParentsMap& parents_map, + const std::unordered_set& divider) { + array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs()); + + auto& x_parents = parents_map[x.id()]; + auto& y_parents = parents_map[y.id()]; + + for (auto it = x_parents.begin(); it != x_parents.end();) { + if (divider.find(it->first.id()) != divider.end()) { + it->first.inputs()[it->second] = y; + y_parents.emplace_back(std::move(*it)); + it = x_parents.erase(it); + } else { + it++; + } + } + + return std::move(y); +} + template std::uintptr_t get_function_address(const std::function& fun) { using FunType = T (*)(U...); @@ -609,10 +698,16 @@ void compile_fuse( } // Arrays with a mix of parents outside the compilable section - // are not fusable + // are not fusable except for broadcast which we can split to avoid + // stopping fusion if (!all_parents_in) { - // Possible input - input_set.insert(a.id()); + if (a.has_primitive() && is_broadcast(a.primitive())) { + array b = split_one(a, parents_map, cache); + recurse(b, depth, s, shape); + } else { + // Possible input + input_set.insert(a.id()); + } return; } diff --git a/mlx/device.cpp b/mlx/device.cpp index e635782e2..ec17a509a 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -1,23 +1,28 @@ // Copyright © 2023 Apple Inc. +#include + +#include "mlx/backend/cpu/available.h" +#include "mlx/backend/gpu/available.h" #include "mlx/device.h" -#include "mlx/backend/metal/metal.h" namespace mlx::core { -static Device default_device_{ - metal::is_available() ? Device::gpu : Device::cpu}; +Device& mutable_default_device() { + static Device default_device{gpu::is_available() ? 
Device::gpu : Device::cpu}; + return default_device; +} const Device& default_device() { - return default_device_; + return mutable_default_device(); } void set_default_device(const Device& d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[set_default_device] Cannot set gpu device without gpu backend."); } - default_device_ = d; + mutable_default_device() = d; } bool operator==(const Device& lhs, const Device& rhs) { @@ -28,4 +33,15 @@ bool operator!=(const Device& lhs, const Device& rhs) { return !(lhs == rhs); } +bool is_available(const Device& d) { + switch (d.type) { + case Device::cpu: + return cpu::is_available(); + case Device::gpu: + return gpu::is_available(); + } + // appease compiler + return false; +} + } // namespace mlx::core diff --git a/mlx/device.h b/mlx/device.h index a11e40e9d..80c624c1c 100644 --- a/mlx/device.h +++ b/mlx/device.h @@ -26,4 +26,6 @@ void set_default_device(const Device& d); bool operator==(const Device& lhs, const Device& rhs); bool operator!=(const Device& lhs, const Device& rhs); +bool is_available(const Device& d); + } // namespace mlx::core diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index e80a1759f..6a440c319 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -225,6 +225,8 @@ struct MPIWrapper { return mpi_bfloat16_; case float64: return mpi_double_; + default: + throw std::runtime_error("Invalid type"); } } diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h index 7320e6cb6..7ad00a0d6 100644 --- a/mlx/distributed/primitives.h +++ b/mlx/distributed/primitives.h @@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive { const std::vector& argnums, const std::vector& outputs) override; - void print(std::ostream& os) override { + const char* name() const override { switch (reduce_type_) { case And: - os << "And"; + return "And AllReduce"; case Or: - os << "And"; - break; + return "Or AllReduce"; case Sum: - os << "Sum"; - break; + return "Sum AllReduce"; case Prod: - os << "Prod"; - break; + return "Prod AllReduce"; case Min: - os << "Min"; - break; + return "Min AllReduce"; case Max: - os << "Max"; - break; + return "Max AllReduce"; } - os << " AllReduce"; + return ""; } private: @@ -94,7 +89,7 @@ class AllGather : public DistPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(AllGather); + DEFINE_NAME(AllGather); }; class Send : public DistPrimitive { @@ -110,7 +105,7 @@ class Send : public DistPrimitive { const std::vector& inputs, const std::vector& axes) override; - DEFINE_PRINT(Send); + DEFINE_NAME(Send); private: int dst_; @@ -126,7 +121,7 @@ class Recv : public DistPrimitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(Recv); + DEFINE_NAME(Recv); private: int src_; diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp index a4448536d..9f10e6a9a 100644 --- a/mlx/dtype_utils.cpp +++ b/mlx/dtype_utils.cpp @@ -5,16 +5,38 @@ namespace mlx::core { const char* dtype_to_string(Dtype arg) { - if (arg == bool_) { - return "bool"; + switch (arg) { + case bool_: + return "bool"; + case int8: + return "int8"; + case int16: + return "int16"; + case int32: + return "int32"; + case int64: + return "int64"; + case uint8: + return "uint8"; + case uint16: + return "uint16"; + case uint32: + return "uint32"; + case uint64: + return "uint64"; + case float16: + return "float16"; + case bfloat16: + return "bfloat16"; + 
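`mutable_default_device()` above (and `ParallelFileReader::thread_pool()` later in this diff) switch from a namespace-scope global to a function-local static, which is constructed on first use and so sidesteps the unspecified cross-translation-unit initialization order of globals. A generic sketch of the idiom with placeholder types:

```cpp
#include <iostream>
#include <string>

struct Config {
  std::string backend;
};

// Constructed lazily on first call (thread-safe since C++11). A namespace-scope
// global would instead be constructed at an unspecified point relative to
// globals in other translation units.
Config& mutable_config() {
  static Config config{"cpu"};
  return config;
}

const Config& config() {
  return mutable_config();
}

void set_backend(const std::string& b) {
  mutable_config().backend = b;
}

int main() {
  set_backend("gpu");
  std::cout << config().backend << '\n';  // gpu
}
```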
case float32: + return "float32"; + case float64: + return "float64"; + case complex64: + return "complex64"; + default: + return "unknown"; } -#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ - if (DTYPE == arg) { \ - return #DTYPE; \ - } - MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) -#undef SPECIALIZE_DtypeToString - return "(unknown)"; } } // namespace mlx::core diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h index 55de890f2..27fe432f6 100644 --- a/mlx/dtype_utils.h +++ b/mlx/dtype_utils.h @@ -1,207 +1,106 @@ // Copyright © 2025 Apple Inc. -// Copyright © Meta Platforms, Inc. and affiliates. -// -// This source code is licensed under the BSD-style license found in -// https://github.com/pytorch/executorch/blob/main/LICENSE -// -// Forked from -// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h #pragma once -#include "mlx/dtype.h" +#include -#include +#include "mlx/dtype.h" +#include "mlx/utils.h" namespace mlx::core { // Return string representation of dtype. const char* dtype_to_string(Dtype arg); -// Macros that iterate across different subsets of Dtypes. -// -// For all of these macros, the final `_` parameter is the name of another macro -// that takes two parameters: the name of a C type, and the name of the -// corresponding Dtype enumerator. -// -// Note that these macros should use fully-qualified namespaces (starting with -// `::`) to ensure that they can be called safely in any arbitrary namespace. -#define MLX_FORALL_INT_TYPES(_) \ - _(uint8_t, uint8) \ - _(uint16_t, uint16) \ - _(uint32_t, uint32) \ - _(uint64_t, uint64) \ - _(int8_t, int8) \ - _(int16_t, int16) \ - _(int32_t, int32) \ - _(int64_t, int64) +#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \ + case DTYPE: \ + f(type_identity{}); \ + break -#define MLX_FORALL_FLOAT_TYPES(_) \ - _(float16_t, float16) \ - _(float, float32) \ - _(double, float64) \ - _(bfloat16_t, bfloat16) +#define MLX_INTERNAL_DTYPE_SWITCH_INTS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t) -// Calls the provided macro on every Dtype, providing the C type and the -// Dtype name to each call. -// -// @param _ A macro that takes two parameters: the name of a C type, and the -// name of the corresponding Dtype enumerator. -#define MLX_FORALL_DTYPES(_) \ - MLX_FORALL_INT_TYPES(_) \ - MLX_FORALL_FLOAT_TYPES(_) \ - _(bool, bool_) \ - _(complex64_t, complex64) +#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double) -// Maps Dtypes to C++ types. -template -struct DtypeToCppType; - -#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \ - template <> \ - struct DtypeToCppType { \ - using type = CPP_TYPE; \ - }; - -MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType) - -#undef SPECIALIZE_DtypeToCppType - -// Maps C++ types to Dtypes. +// This already exists in C++20 but in C++20 we can also just use templated +// lambdas which will make this so much nicer. 
template -struct CppTypeToDtype; +struct type_identity { + using type = T; +}; -#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \ - template <> \ - struct CppTypeToDtype \ - : std::integral_constant {}; +#define MLX_GET_TYPE(x) typename decltype(x)::type +#define MLX_GET_VALUE(x) decltype(x)::value -MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype) - -#undef SPECIALIZE_CppTypeToDtype - -// Helper macros for switch case macros (see below) -// -// These macros are not meant to be used directly. They provide an easy way to -// generate a switch statement that can handle subsets of Dtypes supported. - -#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ - case enum_type: { \ - using CTYPE_ALIAS = ::mlx::core::DtypeToCppType::type; \ - __VA_ARGS__; \ - break; \ +template +void dispatch_all_types(Dtype dt, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t); } +} -#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \ - switch (TYPE) { \ - __VA_ARGS__ \ - default: \ - throw std::invalid_argument(fmt::format( \ - "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \ +template +void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + default: + std::ostringstream msg; + msg << tag << " Only integer types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); } +} -#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__) +template +void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only float types supported but " << dt << " was provided"; + throw std::invalid_argument(msg.str()); + } +} -#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__) +template +void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only integer and float types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} -#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) 
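The dispatch_* functions above replace the MLX_SWITCH_* macro family: a single switch on the runtime dtype invokes a generic callable with a `type_identity<T>` tag, and the callee recovers T via `decltype`. A self-contained sketch of the mechanism with a reduced, made-up dtype enum:

```cpp
#include <cstdint>
#include <iostream>

template <typename T>
struct type_identity {
  using type = T;
};

enum class Dt { i32, f32, f64 };

// Switch once on the runtime dtype and hand the static type to the callable.
template <typename F>
void dispatch(Dt dt, F&& f) {
  switch (dt) {
    case Dt::i32:
      f(type_identity<int32_t>{});
      break;
    case Dt::f32:
      f(type_identity<float>{});
      break;
    case Dt::f64:
      f(type_identity<double>{});
      break;
  }
}

int main() {
  dispatch(Dt::f32, [](auto tag) {
    using T = typename decltype(tag)::type;  // what MLX_GET_TYPE expands to
    std::cout << "itemsize = " << sizeof(T) << '\n';  // itemsize = 4
  });
}
```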
\ - MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__) - -// Switch case macros -// -// These macros provide an easy way to generate switch statements that apply a -// common lambda function to subsets of Dtypes supported by MLX. -// The lambda function can type specialize to the ctype associated with the -// Dtype being handled through an alias passed as the CTYPE_ALIAS argument. -// -// Arguments: -// - ADDITIONAL: Additional Dtype case to add -// - TYPE: The Dtype to handle through the switch statement -// - NAME: A name for this operation which will be used in error messages -// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype. -// - ...: A statement to be applied to each Dtype case -// -// An example usage is: -// -// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, { -// output.data[0] = input.data[0]; -// }); -// -// Note that these can be nested as well: -// -// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, { -// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, { -// output.data[0] = input.data[0]; -// }); -// }); -// -// These macros are adapted from Dispatch.h in the ATen library. The primary -// difference is that the CTYPE_ALIAS argument is exposed to users, which is -// used to alias the ctype associated with the Dtype that is being handled. - -#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \ - switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) } - -#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) 
\ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) +template +void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only real numbers supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} } // namespace mlx::core diff --git a/mlx/export.cpp b/mlx/export.cpp index effc7a0c1..8eb385bb1 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -266,6 +266,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Floor), SERIALIZE_PRIMITIVE(Full), SERIALIZE_PRIMITIVE(Gather), + SERIALIZE_PRIMITIVE(GatherAxis), SERIALIZE_PRIMITIVE(GatherMM), SERIALIZE_PRIMITIVE(Greater), SERIALIZE_PRIMITIVE(GreaterEqual), @@ -307,6 +308,7 @@ struct PrimitiveFactory { "CumMax", "CumLogaddexp"), SERIALIZE_PRIMITIVE(Scatter), + SERIALIZE_PRIMITIVE(ScatterAxis), SERIALIZE_PRIMITIVE(Select), SERIALIZE_PRIMITIVE(Sigmoid), SERIALIZE_PRIMITIVE(Sign), @@ -331,6 +333,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(SVD), SERIALIZE_PRIMITIVE(Inverse), SERIALIZE_PRIMITIVE(Cholesky), + SERIALIZE_PRIMITIVE(Eig), SERIALIZE_PRIMITIVE(Eigh), SERIALIZE_PRIMITIVE(AffineQuantize), SERIALIZE_PRIMITIVE(RMSNorm), @@ -351,9 +354,7 @@ struct PrimitiveFactory { void save(Writer& os, const std::shared_ptr& p) { serialize(os, p->stream()); - std::ostringstream pout; - p->print(pout); - auto name = pout.str(); + std::string name = p->name(); name = name.substr(0, name.find(' ')); if (auto it = name_remap.find(name); it != name_remap.end()) { name = it->second; @@ -470,6 +471,9 @@ bool FunctionTable::match( if (x.dtype() != y.dtype()) { return false; } + if (x.ndim() != y.ndim()) { + return false; + } if (!shapeless && x.shape() != y.shape()) { return false; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 77210f713..210c7f729 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,15 +1,11 @@ // Copyright © 2023-2024 Apple Inc. #include -#include #include -#include -#include "mlx/backend/common/compiled.h" #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/transforms.h" -#include "mlx/transforms_impl.h" namespace mlx::core::fast { @@ -112,7 +108,8 @@ array rms_norm( auto passed_weight = (has_weight) ? astype(*weight, out_type, s) : array(1, out_type); - if (s.device == Device::gpu) { + + if (!RMSNorm::use_fallback(s)) { return array( x.shape(), out_type, @@ -231,13 +228,11 @@ array layer_norm( const std::vector& inputs) { auto x = astype(inputs[0], float32, s); - // Should I not be smart here and leave the double mean to simplify()? auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s); - auto mu2 = square(mu, s); - auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s); - auto v = subtract(x2, mu2, s); + auto xc = subtract(x, mu, s); + auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s); - x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s)); + x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s)); x = astype(x, out_type, s); // If the LN is affine then transform x according to the weight and bias @@ -256,7 +251,7 @@ array layer_norm( auto passed_bias = (has_bias) ? 
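The layer_norm fallback above now computes the variance as mean((x − μ)²) rather than E[x²] − μ². The two are algebraically equal, but the latter subtracts two nearly equal large numbers whenever the mean dominates the spread and can lose the variance entirely in float32. A small standalone demonstration:

```cpp
#include <iostream>
#include <vector>

int main() {
  // Large offset, tiny spread: mean = 10000, true variance = 2/3.
  std::vector<float> x = {9999.0f, 10000.0f, 10001.0f};

  float mu = 0.0f, ex2 = 0.0f;
  for (float v : x) {
    mu += v;
    ex2 += v * v;
  }
  mu /= x.size();
  ex2 /= x.size();

  float var_naive = ex2 - mu * mu;  // E[x^2] - mu^2: catastrophic cancellation
  float var_centered = 0.0f;
  for (float v : x) {
    var_centered += (v - mu) * (v - mu);
  }
  var_centered /= x.size();  // mean((x - mu)^2)

  std::cout << var_naive << " vs " << var_centered << '\n';
  // Prints 0 vs 0.666667 with IEEE float32: the naive form lost the variance.
}
```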
astype(*bias, out_type, s) : array(0, out_type); - if (s.device == Device::gpu) { + if (!LayerNorm::use_fallback(s)) { return array( x.shape(), out_type, @@ -470,7 +465,7 @@ array rope( } }; auto stream = to_stream(s); - if (stream.device == Device::gpu) { + if (!RoPE::use_fallback(stream)) { return array( x.shape(), x.dtype(), @@ -727,31 +722,6 @@ array scaled_dot_product_attention( }; auto stream = to_stream(s); - const int value_head_dim = v.shape(-1); - const int query_head_dim = q.shape(-1); - const int query_sequence_length = q.shape(2); - const int key_sequence_length = k.shape(2); - - const bool sdpa_vector_supported_head_dim = - query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || - query_head_dim == 256); - const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); - - const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || - (query_sequence_length <= key_sequence_length && do_causal); - - const bool supports_sdpa_full = sdpa_full_supported_mask && - sdpa_full_supported_head_dim && stream.device == Device::gpu; - - const bool supports_sdpa_vector = (query_sequence_length <= 8) && - (query_sequence_length <= key_sequence_length) && - sdpa_vector_supported_head_dim && stream.device == Device::gpu; - - const bool implementation_supports_use_case = - supports_sdpa_full || supports_sdpa_vector; - std::vector inputs = {q, k, v}; if (has_arr_mask) { // Check type @@ -770,7 +740,8 @@ array scaled_dot_product_attention( mask_shape.back() = keys.shape(-2); inputs.push_back(broadcast_to(mask_arr, mask_shape, stream)); } - if (!detail::in_grad_tracing() && implementation_supports_use_case) { + if (!ScaledDotProductAttention::use_fallback( + q, k, v, has_mask, has_arr_mask, do_causal, stream)) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; return array( std::move(out_shape), @@ -779,7 +750,7 @@ array scaled_dot_product_attention( stream, fallback, scale, do_causal), std::move(inputs)); } - return fallback(inputs)[0]; + return fallback(std::move(inputs))[0]; } bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { @@ -839,14 +810,14 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { if (group_size != 32 && group_size != 64 && group_size != 128) { std::ostringstream msg; msg << "[quantize] The requested group size " << group_size - << " is not supported. The supported group sizes are 64 and 128."; + << " is not supported. The supported group sizes are 32, 64, and 128."; throw std::invalid_argument(msg.str()); } - if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) { + if (bits < 2 || bits > 8 || bits == 7) { std::ostringstream msg; msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 3, 4, 6 and 8."; + << " is not supported. 
The supported bits are 2, 3, 4, 5, 6 and 8."; throw std::invalid_argument(msg.str()); } @@ -1053,303 +1024,4 @@ std::vector AffineQuantize::output_shapes( } } -std::string write_signature( - std::string func_name, - const std::string& header, - const std::string& source, - const std::vector& input_names, - const std::vector& inputs, - const std::vector& output_names, - const std::vector& output_dtypes, - const std::vector>& template_args, - const std::vector& attributes, - const std::vector& shape_infos, - bool atomic_outputs) { - std::string kernel_source; - kernel_source.reserve(header.size() + source.size() + 16384); - kernel_source += header; - // Auto-generate a function signature based on `template_args` - // and the dtype/shape of the arrays passed as `inputs`. - if (!template_args.empty()) { - kernel_source += "template <"; - int i = 0; - for (const auto& [name, arg] : template_args) { - std::string param_type; - if (std::holds_alternative(arg)) { - param_type = "int"; - } else if (std::holds_alternative(arg)) { - param_type = "bool"; - } else if (std::holds_alternative(arg)) { - param_type = "typename"; - } - if (i > 0) { - kernel_source += ", "; - } - kernel_source += param_type; - kernel_source += " "; - kernel_source += name; - i++; - } - kernel_source += ">\n"; - } - kernel_source += "[[kernel]] void "; - kernel_source += func_name; - kernel_source += "(\n"; - - int index = 0; - constexpr int max_constant_array_size = 8; - // Add inputs - for (int i = 0; i < inputs.size(); ++i) { - const auto& name = input_names[i]; - const auto& arr = inputs[i]; - auto dtype = get_type_string(arr.dtype()); - std::string location = - arr.size() < max_constant_array_size ? "constant" : "device"; - std::string ref = arr.ndim() == 0 ? "&" : "*"; - kernel_source += " const "; - kernel_source += location; - kernel_source += " "; - kernel_source += dtype; - kernel_source += ref; - kernel_source += " "; - kernel_source += name; - kernel_source += " [[buffer("; - kernel_source += std::to_string(index); - kernel_source += ")]],\n"; - index++; - // Add input shape, strides and ndim if present in the source - if (arr.ndim() > 0) { - if (shape_infos[i].shape) { - kernel_source += - (" const constant int* " + name + "_shape [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - if (shape_infos[i].strides) { - kernel_source += - (" const constant int64_t* " + name + "_strides [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - if (shape_infos[i].ndim) { - kernel_source += - (" const constant int& " + name + "_ndim [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - } - } - // Add outputs - for (int i = 0; i < output_names.size(); ++i) { - const auto& name = output_names[i]; - const auto& dtype = output_dtypes[i]; - kernel_source += " device "; - auto type_string = get_type_string(dtype); - if (atomic_outputs) { - kernel_source += "atomic<"; - } - kernel_source += type_string; - if (atomic_outputs) { - kernel_source += ">"; - } - kernel_source += "* "; - kernel_source += name; - kernel_source += " [[buffer("; - kernel_source += std::to_string(index); - kernel_source += ")]]"; - if (index < inputs.size() + output_names.size() - 1 || - attributes.size() > 0) { - kernel_source += ",\n"; - } else { - kernel_source += ") {\n"; - } - index++; - } - - index = 0; - for (const auto& attr : attributes) { - kernel_source += attr; - if (index < attributes.size() - 1) { - kernel_source += ",\n"; - } else { - kernel_source += ") {\n"; - } - index++; - } - kernel_source += source; 
- kernel_source += "\n}\n"; - return kernel_source; -} - -std::string write_template( - const std::vector>& template_args) { - std::ostringstream template_def; - template_def << "<"; - int i = 0; - for (const auto& [name, arg] : template_args) { - if (i > 0) { - template_def << ", "; - } - if (std::holds_alternative(arg)) { - template_def << std::get(arg); - } else if (std::holds_alternative(arg)) { - template_def << std::get(arg); - } else if (std::holds_alternative(arg)) { - template_def << get_type_string(std::get(arg)); - } - i++; - } - template_def << ">"; - return template_def.str(); -} - -MetalKernelFunction metal_kernel( - const std::string& name, - const std::vector& input_names, - const std::vector& output_names, - const std::string& source, - const std::string& header /* = "" */, - bool ensure_row_contiguous /* = true */, - bool atomic_outputs /* = false */) { - if (output_names.empty()) { - throw std::invalid_argument( - "[metal_kernel] Must specify at least one output."); - } - std::vector shape_infos; - for (auto& n : input_names) { - CustomKernelShapeInfo shape_info; - shape_info.shape = source.find(n + "_shape") != std::string::npos; - shape_info.strides = source.find(n + "_strides") != std::string::npos; - shape_info.ndim = source.find(n + "_ndim") != std::string::npos; - shape_infos.push_back(shape_info); - } - const std::vector> metal_attributes = { - {"dispatch_quadgroups_per_threadgroup", "uint"}, - {"dispatch_simdgroups_per_threadgroup", "uint"}, - {"dispatch_threads_per_threadgroup", "uint3"}, - {"grid_origin", "uint3"}, - {"grid_size", "uint3"}, - {"quadgroup_index_in_threadgroup", "uint"}, - {"quadgroups_per_threadgroup", "uint"}, - {"simdgroup_index_in_threadgroup", "uint"}, - {"simdgroups_per_threadgroup", "uint"}, - {"thread_execution_width", "uint"}, - {"thread_index_in_quadgroup", "uint"}, - {"thread_index_in_simdgroup", "uint"}, - {"thread_index_in_threadgroup", "uint"}, - {"thread_position_in_grid", "uint3"}, - {"thread_position_in_threadgroup", "uint3"}, - {"threadgroup_position_in_grid", "uint3"}, - {"threadgroups_per_grid", "uint3"}, - {"threads_per_grid", "uint3"}, - {"threads_per_simdgroup", "uint"}, - {"threads_per_threadgroup", "uint3"}, - }; - - std::vector attributes; - for (const auto& [attr, dtype] : metal_attributes) { - if (source.find(attr) != std::string::npos) { - attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); - } - } - - return [=, - shape_infos = std::move(shape_infos), - attributes = std::move(attributes)]( - const std::vector& inputs, - const std::vector& output_shapes, - const std::vector& output_dtypes, - std::tuple grid, - std::tuple threadgroup, - const std::vector>& - template_args = {}, - std::optional init_value = std::nullopt, - bool verbose = false, - StreamOrDevice s_ = {}) { - if (inputs.size() != input_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `inputs` to have size " - << input_names.size() << " but got size " << inputs.size() << "." - << std::endl; - throw std::invalid_argument(msg.str()); - } - if (output_shapes.size() != output_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `output_shapes` to have size " - << output_names.size() << " but got size " << output_shapes.size() - << "." 
<< std::endl; - throw std::invalid_argument(msg.str()); - } - if (output_dtypes.size() != output_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `output_dtypes` to have size " - << output_names.size() << " but got size " << output_dtypes.size() - << "." << std::endl; - throw std::invalid_argument(msg.str()); - } - - auto s = to_stream(s_); - if (s.device != Device::gpu) { - throw std::invalid_argument("[metal_kernel] Only supports the GPU."); - } - - std::ostringstream func_name; - std::string template_def = ""; - std::string hash_key = ""; - if (!template_args.empty()) { - std::regex disallowed_chars("\\<|\\>|(, )"); - template_def = write_template(template_args); - hash_key = std::regex_replace(template_def, disallowed_chars, "_"); - hash_key.pop_back(); - } - func_name << "custom_kernel_" << name << hash_key; - std::string kernel_name = func_name.str(); - - std::string kernel_source = write_signature( - kernel_name, - header, - source, - input_names, - inputs, - output_names, - output_dtypes, - template_args, - attributes, - shape_infos, - atomic_outputs); - - if (!template_args.empty()) { - template_def = kernel_name + template_def; - kernel_source += "\ntemplate [[host_name(\""; - kernel_source += kernel_name; - kernel_source += "\")]] [[kernel]] decltype("; - kernel_source += template_def; - kernel_source += ") "; - kernel_source += template_def; - kernel_source += ";\n"; - } - - if (verbose) { - std::cout << "Generated source code for `" << name << "`:" << std::endl - << "```" << std::endl - << kernel_source << std::endl - << "```" << std::endl; - } - - return array::make_arrays( - std::move(output_shapes), - std::move(output_dtypes), - std::make_shared( - s, - std::move(kernel_name), - std::move(kernel_source), - grid, - threadgroup, - shape_infos, - ensure_row_contiguous, - init_value), - std::move(inputs)); - }; -} - } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4d9e505ee..52135adad 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -43,6 +43,8 @@ class RMSNorm : public Custom { float eps) : Custom(stream, fallback), eps_(eps) {} + static bool use_fallback(Stream stream); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -56,7 +58,7 @@ class RMSNorm : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(RMSNorm) + DEFINE_NAME(RMSNorm) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() @@ -65,7 +67,6 @@ class RMSNorm : public Custom { } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -84,14 +85,13 @@ class RMSNormVJP : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(RMSNormVJP) + DEFINE_NAME(RMSNormVJP) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(nullptr, eps_); } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -103,6 +103,8 @@ class LayerNorm : public Custom { float eps) : Custom(stream, fallback), eps_(eps) {} + static bool use_fallback(Stream s); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -116,7 +118,7 @@ class LayerNorm : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(LayerNorm) + DEFINE_NAME(LayerNorm) bool is_equivalent(const Primitive& other) const override; 
DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -124,7 +126,6 @@ class LayerNorm : public Custom { } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -143,14 +144,13 @@ class LayerNormVJP : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(LayerNormVJP) + DEFINE_NAME(LayerNormVJP) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(nullptr, eps_); } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -171,6 +171,8 @@ class RoPE : public Custom { scale_(scale), forward_(forward) {} + static bool use_fallback(Stream s); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -184,7 +186,7 @@ class RoPE : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(RoPE) + DEFINE_NAME(RoPE) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -193,7 +195,6 @@ class RoPE : public Custom { } private: - std::function(std::vector)> fallback_; int dims_; bool traditional_; float base_; @@ -210,6 +211,15 @@ class ScaledDotProductAttention : public Custom { const bool do_causal) : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {} + static bool use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -223,14 +233,13 @@ class ScaledDotProductAttention : public Custom { void eval_gpu(const std::vector& inputs, array& out); bool is_equivalent(const Primitive& other) const override; - DEFINE_PRINT(ScaledDotProductAttention); + DEFINE_NAME(ScaledDotProductAttention); DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return std::make_tuple(nullptr, scale_, do_causal_); } private: - std::function(std::vector)> fallback_; float scale_; bool do_causal_; }; @@ -254,7 +263,7 @@ class AffineQuantize : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(AffineQuantize); + DEFINE_NAME(AffineQuantize); bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; @@ -263,7 +272,6 @@ class AffineQuantize : public Custom { } private: - std::function(std::vector)> fallback_; int group_size_; int bits_; bool dequantize_; @@ -303,7 +311,7 @@ class CustomKernel : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(CustomKernel); + DEFINE_NAME(CustomKernel); private: std::string source_; diff --git a/mlx/fft.cpp b/mlx/fft.cpp index f0d41bf0f..6510faec1 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. 
- #include #include @@ -109,7 +108,7 @@ array fft_impl( for (auto ax : axes) { n.push_back(a.shape(ax)); } - if (real && inverse) { + if (real && inverse && a.ndim() > 0) { n.back() = (n.back() - 1) * 2; } return fft_impl(a, n, axes, real, inverse, s); @@ -185,8 +184,79 @@ array irfftn( StreamOrDevice s /* = {} */) { return fft_impl(a, axes, true, true, s); } + array irfftn(const array& a, StreamOrDevice s /* = {} */) { return fft_impl(a, true, true, s); } +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + shifts.push_back(a.shape(axis) / 2); + } + + return roll(a, shifts, axes, s); +} + +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + int size = a.shape(axis); + shifts.push_back(-(size / 2)); + } + + return roll(a, shifts, axes, s); +} + +// Default versions that operate on all axes +array fftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return fftshift(a, axes, s); +} + +array ifftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return ifftshift(a, axes, s); +} + } // namespace mlx::core::fft diff --git a/mlx/fft.h b/mlx/fft.h index 2f02da73b..163e06b80 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -145,5 +145,23 @@ inline array irfft2( StreamOrDevice s = {}) { return irfftn(a, axes, s); } +/** Shift the zero-frequency component to the center of the spectrum. */ +array fftshift(const array& a, StreamOrDevice s = {}); + +/** Shift the zero-frequency component to the center of the spectrum along + * specified axes. */ +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** The inverse of fftshift. */ +array ifftshift(const array& a, StreamOrDevice s = {}); + +/** The inverse of fftshift along specified axes. 
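The fftshift/ifftshift added above reduce to a roll by ⌊n/2⌋ along each axis (negated for the inverse), matching NumPy. A standalone sketch on a plain vector; the local `roll` helper follows the np.roll convention and is not mlx::core::roll.

```cpp
#include <iostream>
#include <vector>

// np.roll convention: out[(i + shift) mod n] = in[i].
std::vector<int> roll(const std::vector<int>& a, int shift) {
  int n = static_cast<int>(a.size());
  std::vector<int> out(n);
  for (int i = 0; i < n; ++i) {
    out[((i + shift) % n + n) % n] = a[i];
  }
  return out;
}

std::vector<int> fftshift(const std::vector<int>& a) {
  return roll(a, static_cast<int>(a.size()) / 2);  // shift = n / 2
}

std::vector<int> ifftshift(const std::vector<int>& a) {
  return roll(a, -(static_cast<int>(a.size()) / 2));  // shift = -(n / 2)
}

int main() {
  std::vector<int> freqs = {0, 1, 2, 3, -3, -2, -1};    // FFT bin order, n = 7
  for (int v : fftshift(freqs)) std::cout << v << ' ';  // -3 -2 -1 0 1 2 3
  std::cout << '\n';
  for (int v : ifftshift(fftshift(freqs))) std::cout << v << ' ';  // original order
  std::cout << '\n';
}
```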
*/ +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); } // namespace mlx::core::fft diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 29373f266..854881bc9 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -93,7 +93,7 @@ void print_graph( os << "\n"; for (auto& arr : tape) { - arr.primitive().print(os); + os << arr.primitive().name(); os << " "; print_arrs(arr.inputs()); os << " -> "; @@ -143,7 +143,7 @@ void export_to_dot( os << "{ "; os << x.primitive_id(); os << " [label =\""; - x.primitive().print(os); + os << x.primitive().name(); os << "\", shape=rectangle]"; os << "; }" << std::endl; // Arrows to primitive's inputs diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 59d91c007..2f9053f4d 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -335,7 +335,10 @@ ThreadPool& thread_pool() { return pool_; } -ThreadPool ParallelFileReader::thread_pool_{4}; +ThreadPool& ParallelFileReader::thread_pool() { + static ThreadPool thread_pool{4}; + return thread_pool; +} void ParallelFileReader::read(char* data, size_t n) { while (n != 0) { @@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) { break; } else { size_t m = batch_size_; - futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); + futs.emplace_back( + ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data)); data += m; n -= m; offset += m; diff --git a/mlx/io/load.h b/mlx/io/load.h index 138098e82..8b5dd95b6 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -101,7 +101,7 @@ class ParallelFileReader : public Reader { private: static constexpr size_t batch_size_ = 1 << 25; - static ThreadPool thread_pool_; + static ThreadPool& thread_pool(); int fd_; std::string label_; }; diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 5b9b51ad3..e8a9e430e 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) { } } +void check_float_or_complex(Dtype dtype, const std::string& prefix) { + if (dtype != float32 && dtype != float64 && dtype != complex64) { + std::ostringstream msg; + msg << prefix << " Arrays must have type float32, float64 or complex64. " + << "Received array with type " << dtype << "."; + throw std::invalid_argument(msg.str()); + } +} + Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } @@ -379,7 +388,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) { // Prepare S S = expand_dims(S, -2, s); - return matmul(divide(V, S, s), U); + auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps; + auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s); + auto rS = + where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s); + + return matmul(multiply(V, rS, s), U, s); } array cholesky_inv( @@ -483,12 +497,12 @@ array cross( return concatenate(outputs, axis, s); } -void validate_eigh( +void validate_eig( const array& a, const StreamOrDevice& stream, - const std::string fname) { + const std::string& fname) { check_cpu_stream(stream, fname); - check_float(a.dtype(), fname); + check_float_or_complex(a.dtype(), fname); if (a.ndim() < 2) { std::ostringstream msg; @@ -506,11 +520,12 @@ array eigvalsh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, s, "[linalg::eigvalsh]"); + validate_eig(a, s, "[linalg::eigvalsh]"); Shape out_shape(a.shape().begin(), a.shape().end() - 1); + Dtype eigval_type = a.dtype() == complex64 ? 
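The pinv change above clips singular values below rcond · max(S), with rcond = 10 · max(m, n) · eps, instead of taking 1/S unconditionally, so near-zero singular values map to zero rather than blowing up. A scalar sketch of just that thresholding step, starting from singular values an SVD would return (the helper name is illustrative):

```cpp
#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

// Given the singular values S of an m x n matrix, return the reciprocals used
// to assemble pinv, with anything below rcond * max(S) zeroed out.
std::vector<float> reciprocal_singular_values(std::vector<float> S, int m, int n) {
  float eps = std::numeric_limits<float>::epsilon();
  float rcond = 10.0f * std::max(m, n) * eps;
  float cutoff = rcond * *std::max_element(S.begin(), S.end());
  for (auto& s : S) {
    s = (s > cutoff) ? 1.0f / s : 0.0f;
  }
  return S;
}

int main() {
  // One well-conditioned direction and one that is numerically zero.
  for (float r : reciprocal_singular_values({3.0f, 1e-9f}, 4, 3)) {
    std::cout << r << ' ';  // 0.333333 0
  }
  std::cout << '\n';
}
```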
float32 : a.dtype(); return array( std::move(out_shape), - a.dtype(), + eigval_type, std::make_shared(to_stream(s), UPLO, false), {a}); } @@ -519,15 +534,36 @@ std::pair eigh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, s, "[linalg::eigh]"); + validate_eig(a, s, "[linalg::eigh]"); + Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype(); auto out = array::make_arrays( {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, - {a.dtype(), a.dtype()}, + {eigval_type, a.dtype()}, std::make_shared(to_stream(s), UPLO, true), {a}); return std::make_pair(out[0], out[1]); } +array eigvals(const array& a, StreamOrDevice s /* = {} */) { + validate_eig(a, s, "[linalg::eigvals]"); + Shape out_shape(a.shape().begin(), a.shape().end() - 1); + return array( + std::move(out_shape), + complex64, + std::make_shared(to_stream(s), false), + {a}); +} + +std::pair eig(const array& a, StreamOrDevice s /* = {} */) { + validate_eig(a, s, "[linalg::eig]"); + auto out = array::make_arrays( + {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, + {complex64, complex64}, + std::make_shared(to_stream(s), true), + {a}); + return std::make_pair(out[0], out[1]); +} + void validate_lu( const array& a, const StreamOrDevice& stream, @@ -652,7 +688,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { perm = expand_dims(perm, -1, s); take_axis -= 1; } - auto pb = take_along_axis(b, perm, take_axis); + auto pb = take_along_axis(b, perm, take_axis, s); auto y = solve_triangular(luf[1], pb, /* upper = */ false, s); return solve_triangular(luf[2], y, /* upper = */ true, s); } diff --git a/mlx/linalg.h b/mlx/linalg.h index 8c3a2070a..0690fba95 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -99,6 +99,10 @@ array cross( int axis = -1, StreamOrDevice s = {}); +std::pair eig(const array& a, StreamOrDevice s = {}); + +array eigvals(const array& a, StreamOrDevice s = {}); + array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); std::pair diff --git a/mlx/mlx.h b/mlx/mlx.h index cef8d806d..de3ee392a 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/cuda.h" #include "mlx/backend/metal/metal.h" #include "mlx/compile.h" #include "mlx/device.h" diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 54ac62fef..7161a39b2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -472,9 +472,24 @@ array hadamard_transform( const array& a, std::optional scale_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { + if (a.size() == 0) { + throw std::invalid_argument( + "[hadamard_transform] Does not support empty arrays."); + } // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) - float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1)); + int n = a.ndim() > 0 ? a.shape(-1) : 1; + float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n); auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32; + + // Nothing to do for a scalar + if (n == 1) { + if (scale == 1) { + return a; + } + + return multiply(a, array(scale, dtype), s); + } + return array( a.shape(), dtype, @@ -2832,6 +2847,27 @@ array matmul( "[matmul] Got 0 dimension input. 
Inputs must " "have at least one dimension."); } + + // complex matmul using Karatsuba's Algorithm + if (a.dtype() == complex64 || b.dtype() == complex64) { + // Extract real and imaginary parts + auto a_real = real(a, s); + auto a_imag = imag(a, s); + auto b_real = real(b, s); + auto b_imag = imag(b, s); + + // Compute real and imaginary components of the result + auto m1 = matmul(a_real, b_real, s); + auto m2 = matmul(a_imag, b_imag, s); + auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s); + + auto c_real = subtract(m1, m2, s); + auto c_imag = subtract(m3, add(m1, m2, s), s); + + return add( + c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); + } + if (a.ndim() == 1) { // Insert a singleton dim in the beginning a = expand_dims(a, 0, s); @@ -2847,20 +2883,9 @@ array matmul( << " second input with shape " << b.shape() << "."; throw std::invalid_argument(msg.str()); } + // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); - // Complex matmul in terms of real matmuls - if (out_type == complex64) { - auto a_real = real(a, s); - auto b_real = real(b, s); - auto a_imag = imag(a, s); - auto b_imag = imag(b, s); - auto c_real = - subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s); - auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s); - return add( - c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); - } if (!issubdtype(out_type, floating)) { std::ostringstream msg; @@ -3160,6 +3185,10 @@ array scatter_axis( throw std::invalid_argument(msg.str()); } + if (a.size() == 0) { + return a; + } + auto upd = astype(values, a.dtype(), s); // Squeeze leading singletons out of update @@ -3565,21 +3594,21 @@ Shape conv_out_shape( if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for " + msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } if (kernel_dilation.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for " + msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } if (input_dilation.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid input dilation " << input_dilation << "for " + msg << "[conv] Invalid input dilation " << input_dilation << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } @@ -3769,6 +3798,7 @@ array conv_transpose_general( std::vector stride, std::vector padding, std::vector dilation, + std::vector output_padding, int groups, StreamOrDevice s) { std::vector padding_lo(padding.size()); @@ -3782,7 +3812,8 @@ array conv_transpose_general( int in_size = 1 + (conv_output_shape - 1); int out_size = 1 + stride[i] * (input.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding[i]; + padding_hi[i] = in_size - out_size + padding[i] + + output_padding[i]; // Adjust with output_padding } return conv_general( @@ -3805,10 +3836,11 @@ array conv_transpose1d( int stride /* = 1 */, int padding /* = 0 */, int dilation /* = 1 */, + int output_padding /* = 0 */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( - in_, wt_, {stride}, {padding}, {dilation}, groups, s); + in_, wt_, {stride}, {padding}, {dilation}, 
{output_padding}, groups, s); } /** 2D transposed convolution with a filter */ @@ -3818,6 +3850,7 @@ array conv_transpose2d( const std::pair& stride /* = {1, 1} */, const std::pair& padding /* = {0, 0} */, const std::pair& dilation /* = {1, 1} */, + const std::pair& output_padding /* = {0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3826,6 +3859,7 @@ array conv_transpose2d( {stride.first, stride.second}, {padding.first, padding.second}, {dilation.first, dilation.second}, + {output_padding.first, output_padding.second}, groups, s); } @@ -3837,6 +3871,7 @@ array conv_transpose3d( const std::tuple& stride /* = {1, 1, 1} */, const std::tuple& padding /* = {0, 0, 0} */, const std::tuple& dilation /* = {1, 1, 1} */, + const std::tuple& output_padding /* = {0, 0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3845,6 +3880,9 @@ array conv_transpose3d( {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)}, {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)}, {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)}, + {std::get<0>(output_padding), + std::get<1>(output_padding), + std::get<2>(output_padding)}, groups, s); } @@ -3954,6 +3992,7 @@ array conv_general( to_stream(s), stride, padding_lo, + padding_hi, kernel_dilation, input_dilation, groups, @@ -4202,6 +4241,16 @@ array addmm( "have at least one dimension."); } + // Type promotion + auto out_type = result_type(a, b, c); + + if (out_type == complex64) { + return add( + multiply(matmul(a, b, s), array(alpha), s), + multiply(array(beta), c, s), + s); + } + if (a.ndim() == 1) { // Insert a singleton dim in the beginning a = expand_dims(a, 0, s); @@ -4219,16 +4268,6 @@ array addmm( throw std::invalid_argument(msg.str()); } - // Type promotion - auto out_type = result_type(a, b, c); - - if (out_type == complex64) { - return add( - multiply(matmul(a, b, s), array(alpha), s), - multiply(array(beta), c, s), - s); - } - if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " @@ -4305,6 +4344,10 @@ array addmm( c = reshape(c, c_reshape, s); } + if (c.shape() != out_shape) { + throw std::invalid_argument( + "[addmm] input c must broadcast to the output shape"); + } auto out = array( std::move(out_shape), @@ -4606,6 +4649,54 @@ array gather_mm( return axes.empty() ? out : squeeze(out, axes, s); } +array segmented_mm( + array a, + array b, + array segments, + StreamOrDevice s /* = {} */) { + if (a.ndim() != 2 || b.ndim() != 2) { + throw std::invalid_argument("[segmented_mm] Batched matmul not supported"); + } + + if (segments.ndim() < 1 || segments.shape().back() != 2) { + std::ostringstream msg; + msg << "[segmented_mm] The segments should have shape (..., 2) but " + << segments.shape() << " was provided."; + throw std::invalid_argument(msg.str()); + } + + // Type promotion + auto out_type = result_type(a, b); + if (!issubdtype(out_type, floating)) { + std::ostringstream msg; + msg << "[segmented_mm] Only real floating point types are supported but " + << a.dtype() << " and " << b.dtype() + << " were provided which results in " << out_type + << ", which is not a real floating point type."; + throw std::invalid_argument(msg.str()); + } + + if (!issubdtype(segments.dtype(), integer)) { + throw std::invalid_argument( + "[segmented_mm] Got segments with invalid dtype. 
Segments must be integral."); + } + + a = astype(a, out_type, s); + b = astype(b, out_type, s); + segments = astype(segments, uint32, s); + + Shape out_shape = segments.shape(); + out_shape.pop_back(); + out_shape.push_back(a.shape(0)); + out_shape.push_back(b.shape(1)); + + return array( + std::move(out_shape), + out_type, + std::make_shared(to_stream(s)), + {std::move(a), std::move(b), std::move(segments)}); +} + array diagonal( const array& a, int offset /* = 0 */, @@ -4873,8 +4964,9 @@ array bitwise_impl( const array& b, BitwiseBinary::Op op, const std::string& op_name, - const StreamOrDevice& s) { - auto out_type = promote_types(a.dtype(), b.dtype()); + const StreamOrDevice& s, + std::optional out_type_ = std::nullopt) { + auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype()); if (!(issubdtype(out_type, integer) || out_type == bool_)) { std::ostringstream msg; msg << "[" << op_name @@ -4919,12 +5011,7 @@ array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { if (t == bool_) { t = uint8; } - return bitwise_impl( - astype(a, t, s), - astype(b, t, s), - BitwiseBinary::Op::LeftShift, - "left_shift", - s); + return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t); } array operator<<(const array& a, const array& b) { return left_shift(a, b); @@ -4940,7 +5027,8 @@ array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { astype(b, t, s), BitwiseBinary::Op::RightShift, "right_shift", - s); + s, + t); } array operator>>(const array& a, const array& b) { return right_shift(a, b); @@ -5019,8 +5107,11 @@ array roll( } auto sh = shift[i]; - auto split_index = - (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); + auto size = a.shape(ax); + if (size == 0) { + continue; // skip rolling this axis if it has size 0 + } + auto split_index = (sh < 0) ? (-sh) % size : size - sh % size; auto parts = split(result, Shape{split_index}, ax, s); std::swap(parts[0], parts[1]); diff --git a/mlx/ops.h b/mlx/ops.h index e79ea235d..596d6d287 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -569,7 +569,7 @@ inline array std(const array& a, StreamOrDevice s = {}) { return std(a, false, 0, to_stream(s)); } -/** Computes the standard deviatoin of the elements of an array along the given +/** Computes the standard deviation of the elements of an array along the given * axes */ array std( const array& a, @@ -1291,6 +1291,7 @@ array conv_transpose1d( int stride = 1, int padding = 0, int dilation = 1, + int output_padding = 0, int groups = 1, StreamOrDevice s = {}); @@ -1301,6 +1302,7 @@ array conv_transpose2d( const std::pair& stride = {1, 1}, const std::pair& padding = {0, 0}, const std::pair& dilation = {1, 1}, + const std::pair& output_padding = {0, 0}, int groups = 1, StreamOrDevice s = {}); @@ -1311,6 +1313,7 @@ array conv_transpose3d( const std::tuple& stride = {1, 1, 1}, const std::tuple& padding = {0, 0, 0}, const std::tuple& dilation = {1, 1, 1}, + const std::tuple& output_padding = {0, 0, 0}, int groups = 1, StreamOrDevice s = {}); @@ -1403,6 +1406,12 @@ array gather_mm( bool sorted_indices = false, StreamOrDevice s = {}); +/** + * Compute a matrix product but segment the inner dimension and write the + * result separately for each segment. 
+ */ +array segmented_mm(array a, array b, array segments, StreamOrDevice s = {}); + /** Extract a diagonal or construct a diagonal array */ array diagonal( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 590af60f6..cf0e6ef0d 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -109,6 +109,70 @@ std::tuple vmap_ternary_op( return {a, b, c, to_ax}; } +// Calculate the gradient wrt to the weights of the following calculation +// +// y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted) +// +// Note the transpose above. This function returns the gradient for w.T so if w +// was used instead then one needs to transpose the returned gradient. +// +// We define it as a separate function to reuse it for gather_mm and +// gather_qmm. +array gather_mm_grad( + const array& x, + const array& dy, + const array& lhs_indices, + const array& rhs_indices, + bool sorted, + Shape batch_shape, + const Stream& s) { + int M = x.shape(-2); + int K = x.shape(-1); + int N = dy.shape(-1); + int num_segments = std::accumulate( + batch_shape.begin(), batch_shape.end(), 1, std::multiplies()); + batch_shape.push_back(N); + batch_shape.push_back(K); + + // If the indices are sorted then it means that we can do the whole gradient + // computation via a segmented matmul. We just need to calculate the segments + // using the indices. + if (sorted) { + auto segments = zeros({num_segments}, uint32, s); + segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s); + segments = cumsum(segments, 0, false, true, s); + segments = concatenate({array({0}, {1}, uint32), segments}, 0, s); + segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s); + + return reshape( + segmented_mm( + swapaxes(flatten(dy, 0, -2, s), 0, 1, s), + flatten(x, 0, -2, s), + segments, + s), + std::move(batch_shape), + s); + } + + // Otherwise we need to gather matmul the dy and then scatter add it to the + // correct locations. + else { + // TODO: If the lhs indices wasn't provided, this is always a sorted matmul + // so we should add that check. 
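The sorted branch above collapses the whole weight gradient into a single segmented_mm: the zeros / scatter_add_axis / cumsum / concatenate / as_strided chain is just a histogram of the indices (scaled by the M rows each matmul contributes) turned into running (start, end) offsets. A standalone sketch of that arithmetic, assuming made-up indices; segments_from_sorted is a hypothetical helper for illustration, not part of the MLX API:

#include <cstdio>
#include <utility>
#include <vector>

// Hypothetical helper: mimic the segment computation the sorted branch builds
// with MLX ops. Each occurrence of a weight index contributes M contiguous
// rows of the flattened activations.
std::vector<std::pair<int, int>> segments_from_sorted(
    const std::vector<int>& rhs_indices,
    int num_segments,
    int M) {
  std::vector<int> counts(num_segments, 0);
  for (int i : rhs_indices) {
    counts[i] += M; // histogram of indices, scaled by rows per matmul
  }
  std::vector<std::pair<int, int>> segments;
  int start = 0;
  for (int c : counts) {
    segments.push_back({start, start + c}); // contiguous [start, end) block
    start += c;
  }
  return segments;
}

int main() {
  // Three weight matrices, five gathered matmuls routed as 0, 0, 1, 2, 2,
  // each contributing M = 4 rows.
  for (auto [s, e] : segments_from_sorted({0, 0, 1, 2, 2}, 3, /*M=*/4)) {
    std::printf("[%d, %d) ", s, e); // prints [0, 8) [8, 12) [12, 20)
  }
  std::printf("\n");
}

Because the indices are sorted, these blocks are contiguous, which is what lets one segmented_mm replace the gather-then-scatter_add path taken in the unsorted branch below.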
+ auto dw = gather_mm( + swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s); + return reshape( + scatter_add( + zeros({num_segments, N, K}, dw.dtype(), s), + rhs_indices, + expand_dims(dw, -3, s), + 0, + s), + std::move(batch_shape), + s); + } +} + } // namespace std::vector Primitive::jvp( @@ -117,7 +181,7 @@ std::vector Primitive::jvp( const std::vector&) { std::ostringstream msg; msg << "[Primitive::jvp] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -129,7 +193,7 @@ std::vector Primitive::vjp( const std::vector&) { std::ostringstream msg; msg << "[Primitive::vjp] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -139,7 +203,7 @@ std::pair, std::vector> Primitive::vmap( const std::vector&) { std::ostringstream msg; msg << "[Primitive::vmap] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -147,7 +211,7 @@ std::pair, std::vector> Primitive::vmap( std::vector Primitive::output_shapes(const std::vector&) { std::ostringstream msg; msg << "[Primitive::output_shapes] "; - this->print(msg); + msg << name(); msg << " cannot infer output shapes."; throw std::invalid_argument(msg.str()); } @@ -556,10 +620,11 @@ std::vector ArgReduce::vjp( } std::vector ArgReduce::jvp( + const std::vector& primals, const std::vector&, - const std::vector& tangents, const std::vector&) { - return {zeros_like(tangents[0], stream())}; + auto shape = output_shapes(primals)[0]; + return {zeros(shape, uint32, stream())}; } std::pair, std::vector> ArgSort::vmap( @@ -583,6 +648,21 @@ bool ArgSort::is_equivalent(const Primitive& other) const { return axis_ == r_other.axis_; } +std::vector ArgSort::vjp( + const std::vector& primals, + const std::vector&, + const std::vector&, + const std::vector&) { + return {zeros_like(primals[0], stream())}; +} + +std::vector ArgSort::jvp( + const std::vector& primals, + const std::vector&, + const std::vector&) { + return {zeros(primals[0].shape(), uint32, stream())}; +} + std::vector AsType::vjp( const std::vector& primals, const std::vector& cotangents, @@ -663,26 +743,6 @@ bool BitwiseBinary::is_equivalent(const Primitive& other) const { return op_ == a_other.op_; } -void BitwiseBinary::print(std::ostream& os) { - switch (op_) { - case BitwiseBinary::And: - os << "BitwiseAnd"; - break; - case BitwiseBinary::Or: - os << "BitwiseOr"; - break; - case BitwiseBinary::Xor: - os << "BitwiseXor"; - break; - case BitwiseBinary::LeftShift: - os << "LeftShift"; - break; - case BitwiseBinary::RightShift: - os << "RightShift"; - break; - } -} - std::pair, std::vector> BitwiseBinary::vmap( const std::vector& inputs, const std::vector& axes) { @@ -875,6 +935,43 @@ std::pair, std::vector> Cholesky::vmap( return {{linalg::cholesky(a, upper_, stream())}, {ax}}; } +std::pair, std::vector> Eig::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + bool needs_move = axes[0] >= (inputs[0].ndim() - 2); + auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; + auto ax = needs_move ? 
0 : axes[0]; + + std::vector outputs; + if (compute_eigenvectors_) { + auto [values, vectors] = linalg::eig(a, stream()); + outputs = {values, vectors}; + } else { + outputs = {linalg::eigvals(a, stream())}; + } + + return {outputs, std::vector(outputs.size(), ax)}; +} + +std::vector Eig::output_shapes(const std::vector& inputs) { + auto shape = inputs[0].shape(); + shape.pop_back(); // Remove last dimension for eigenvalues + if (compute_eigenvectors_) { + return { + std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors + } else { + return {std::move(shape)}; // Only eigenvalues + } +} + +bool Eig::is_equivalent(const Primitive& other) const { + auto& e_other = static_cast(other); + return compute_eigenvectors_ == e_other.compute_eigenvectors_; +} + std::pair, std::vector> Eigh::vmap( const std::vector& inputs, const std::vector& axes) { @@ -1055,7 +1152,8 @@ array conv_weight_backward_patches( const array& wt, const array& cotan, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, StreamOrDevice s) { // Resolve Padded input shapes and strides Shape padding_starts(in.ndim(), 0); @@ -1064,9 +1162,9 @@ array conv_weight_backward_patches( // padded shape for (int i = 1; i < in.ndim() - 1; i++) { - in_padded_shape[i] += 2 * padding[i - 1]; - padding_ends[i] += padding[i - 1]; - padding_starts[i] += padding[i - 1]; + in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1]; + padding_ends[i] += padding_lo[i - 1]; + padding_starts[i] += padding_lo[i - 1]; } // padded strides (contiguous) @@ -1078,9 +1176,14 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_(padding.begin(), padding.end()); - auto in_padded = pad( - in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s); + auto in_padded = + pad(in, + padded_axes, + Shape(padding_lo), + Shape(padding_hi), + array(0, in.dtype()), + "constant", + s); // Resolve strided patches @@ -1147,16 +1250,16 @@ std::vector Convolution::vjp( for (int a : argnums) { // Grads for input if (a == 0) { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; for (int i = 0; i < padding_lo.size(); ++i) { int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_lo[i] = wt_size - padding_[i] - 1; + padding_lo[i] = wt_size - padding_lo_[i] - 1; int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding_[i]; + padding_hi[i] = in_size - out_size + padding_hi_[i]; } // Check for negative padding @@ -1226,18 +1329,18 @@ std::vector Convolution::vjp( if (no_dilation && !flip_ && groups_ == 1) { auto grad = conv_weight_backward_patches( - in, wt, cotan, kernel_strides_, padding_, stream()); + in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + auto padding_hi = padding_lo_; for (int i = 0; i < padding_hi.size(); ++i) { int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; + padding_hi[i] = out_size - 
in_size + wt_size - padding_hi[i] - 1; } + auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1245,7 +1348,7 @@ std::vector Convolution::vjp( /* const array& input = */ in_trans, /* const array& weight = */ cotan_trans, /* std::vector stride = */ kernel_dilation_, - /* std::vector padding_lo = */ padding_lo, + /* std::vector padding_lo = */ padding_lo_, /* std::vector padding_hi = */ padding_hi, /* std::vector kernel_dilation = */ kernel_strides_, /* std::vector input_dilation = */ input_dilation_, @@ -1275,9 +1378,66 @@ std::vector Convolution::vjp( return grads; } +std::pair, std::vector> Convolution::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto do_conv = [&](const array& in, const array& w, int groups) { + return conv_general( + in, + w, + kernel_strides_, + padding_lo_, + padding_hi_, + kernel_dilation_, + input_dilation_, + groups, + flip_, + stream()); + }; + bool in_vmap = axes[0] >= 0; + bool w_vmap = axes[1] >= 0; + auto in = inputs[0]; + auto w = inputs[1]; + if (in_vmap && !w_vmap) { + // flatten / unflatten the batch dimension + // of the input / output + if (axes[0] > 0) { + in = moveaxis(in, axes[0], 0, stream()); + } + auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_); + out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream()); + return {{out}, {0}}; + } else if (!in_vmap && w_vmap) { + // flatten into the output channels of w + // unflatten the channels of the output + if (axes[1] > 0) { + w = moveaxis(w, axes[1], 0, stream()); + } + auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_); + out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream()); + return {{out}, {static_cast(out.ndim() - 2)}}; + } else if (in_vmap && w_vmap) { + // use a group convolution when both inputs are vmapped + auto b = in.shape(axes[0]); + in = moveaxis(in, axes[0], -2, stream()); + in = flatten(in, -2, -1, stream()); + if (axes[1] > 0) { + w = moveaxis(w, axes[1], 0, stream()); + } + auto c_out = w.shape(1); + w = flatten(w, 0, 1, stream()); + auto out = do_conv(in, w, groups_ * b); + out = unflatten(out, -1, {b, c_out}, stream()); + return {{out}, {static_cast(out.ndim() - 2)}}; + } else { + return {{do_conv(in, w, groups_)}, {-1}}; + } +} + bool Convolution::is_equivalent(const Primitive& other) const { const Convolution& c_other = static_cast(other); - return padding_ == c_other.padding_ && + return padding_lo_ == c_other.padding_lo_ && + padding_hi_ == c_other.padding_hi_ && kernel_strides_ == c_other.kernel_strides_ && kernel_dilation_ == c_other.kernel_dilation_ && input_dilation_ == c_other.input_dilation_ && @@ -1429,14 +1589,16 @@ std::vector Divide::vjp( const std::vector& argnums, const std::vector&) { std::vector vjps; + array denominator_bar = conjugate(primals[1], stream()); for (auto arg : argnums) { if (arg == 0) { - vjps.push_back(divide(cotangents[0], primals[1], stream())); + vjps.push_back(divide(cotangents[0], denominator_bar, stream())); } else { vjps.push_back(negative( divide( - multiply(cotangents[0], primals[0], stream()), - square(primals[1], stream()), + multiply( + cotangents[0], conjugate(primals[0], stream()), stream()), + square(denominator_bar, stream()), stream()), stream())); } @@ -1891,30 +2053,74 @@ std::vector FFT::vjp( assert(argnums.size() == 1); auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); + + // TODO: Add it as an option to do an unnormalized or scaled fft so that this + // isn't part of the graph. 
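The hunk that follows rescales the cotangent by the product of the transformed lengths. This is consistent with the usual convention, which the scaling here implies, that the forward transform is unnormalized and only the inverse carries the 1/N factor: under that assumption the adjoint of fftn is n_elements * ifftn and the adjoint of ifftn is (1 / n_elements) * fftn. A standalone sketch of the arithmetic, with made-up shape and axes:

#include <cstdio>
#include <vector>

int main() {
  // n_elements is the product of the lengths along the transformed axes.
  std::vector<int> shape = {4, 6, 8};
  std::vector<int> axes = {1, 2};
  double n_elements = 1;
  for (int ax : axes) {
    n_elements *= shape[ax]; // 6 * 8 = 48
  }
  // The FFT is linear, so its vjp is the adjoint applied to the cotangent:
  // a forward fftn cotangent gets multiplied by n_elements, an ifftn
  // cotangent by 1 / n_elements.
  std::printf("vjp scale for fftn:  %g\n", n_elements);       // 48
  std::printf("vjp scale for ifftn: %g\n", 1.0 / n_elements); // ~0.0208
}

The real-input branches additionally apply a mask that doubles (or halves) every bin except DC and Nyquist, since those bins appear twice in the implicit full spectrum.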
+ double n_elements = 1; + for (auto ax : axes) { + n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax); + } + if (real_ && inverse_) { - auto out = fft::fftn(cotangents[0], axes, stream()); - auto start = Shape(out.ndim(), 0); - auto stop = in.shape(); - out = slice(out, start, stop, stream()); - auto mask_shape = out.shape(); - mask_shape[axes_.back()] -= 2; - auto mask = full(mask_shape, 2.0f, stream()); - auto pad_shape = out.shape(); - pad_shape[axes_.back()] = 1; - auto pad = full(pad_shape, 1.0f, stream()); - mask = concatenate({pad, mask, pad}, axes_.back(), stream()); - return {multiply(mask, out, stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets doubled. + int N = in.shape(axes_.back()); + bool odd = cotangents[0].shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1 / n_elements, in.dtype()); + array two(2 / n_elements, in.dtype()); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + two, + one, + stream()); + return { + multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())}; } else if (real_) { Shape n; for (auto ax : axes_) { - n.push_back(in.shape()[ax]); + n.push_back(in.shape(ax)); } - return {astype( - fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets halved. + int N = cotangents[0].shape(axes_.back()); + bool odd = in.shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1, complex64); + array half(0.5, complex64); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + half, + one, + stream()); + return {multiply( + fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), + array(n_elements, in.dtype()), + stream())}; } else if (inverse_) { - return {fft::ifftn(cotangents[0], axes, stream())}; + return {multiply( + fft::fftn(cotangents[0], axes, stream()), + array(1 / n_elements, complex64), + stream())}; } else { - return {fft::fftn(cotangents[0], axes, stream())}; + return {multiply( + fft::ifftn(cotangents[0], axes, stream()), + array(n_elements, complex64), + stream())}; } } @@ -2233,7 +2439,7 @@ std::vector Imag::vjp( assert(primals.size() == 1); assert(argnums.size() == 1); return {multiply( - array(complex64_t{0.0f, -1.0f}, primals[0].dtype()), + array(complex64_t{0.0f, 1.0f}, primals[0].dtype()), cotangents[0], stream())}; } @@ -2562,15 +2768,19 @@ std::vector Matmul::vjp( std::vector reorder(cotan.ndim()); std::iota(reorder.begin(), reorder.end(), 0); std::iter_swap(reorder.end() - 1, reorder.end() - 2); + auto& s = stream(); + + auto complex_transpose = [&](const array& x) { + return transpose(conjugate(x, s), reorder, s); + }; + for (auto arg : argnums) { if (arg == 0) { // M X N * (K X N).T -> M X K - vjps.push_back( - matmul(cotan, transpose(primals[1], reorder, stream()), stream())); + vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s)); } else { // (M X K).T * M X N -> K X 
N - vjps.push_back( - matmul(transpose(primals[0], reorder, stream()), cotan, stream())); + vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s)); } } return vjps; @@ -2717,7 +2927,8 @@ std::vector Multiply::vjp( const std::vector&) { std::vector vjps; for (auto arg : argnums) { - vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); + vjps.push_back(multiply( + conjugate(primals[1 - arg], stream()), cotangents[0], stream())); } return vjps; } @@ -3001,6 +3212,7 @@ std::vector QuantizedMatmul::vjp( std::vector vjps; // We rely on the fact that w is always 2D so transpose is simple + std::optional dsb = std::nullopt; for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { @@ -3016,9 +3228,34 @@ std::vector QuantizedMatmul::vjp( } // gradient wrt to w_q, scales or biases - else { + else if (arg == 1) { throw std::runtime_error( - "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet."); + "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); + } else { + if (!dsb) { + int ndim = primals[1].ndim(); + auto fc = flatten(cotangents[0], 0, -ndim, stream()); + auto fx = flatten(primals[0], 0, -ndim, stream()); + auto dw = transpose_ + ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream()) + : matmul(swapaxes(fx, -1, -2, stream()), fc, stream()); + dsb = unflatten(dw, -1, {-1, group_size_}, stream()); + } + if (arg == 3) { + // biases + vjps.push_back(sum(*dsb, -1, false, stream())); + } else { + // scales + auto wq = dequantize( + primals[1], + ones_like(primals[2], stream()), + zeros_like(primals[3], stream()), + group_size_, + bits_, + stream()); + wq = unflatten(wq, -1, {-1, group_size_}, stream()); + vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream())); + } } } return vjps; @@ -3080,34 +3317,42 @@ std::vector GatherQMM::vjp( auto& lhs_indices = primals[4]; auto& rhs_indices = primals[5]; + int M = cotan.shape(-2); + int N = cotan.shape(-1); + int K = x.shape(-1); + bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == x.size(); + std::optional dsb = std::nullopt; for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { - vjps.push_back(reshape( - scatter_add( - flatten(zeros_like(x, stream()), 0, -3, stream()), - lhs_indices, - expand_dims( - gather_qmm( - cotan, - w, - scales, - biases, - std::nullopt, - rhs_indices, - !transpose_, - group_size_, - bits_, - sorted, - stream()), - -3, - stream()), - 0, - stream()), - x.shape(), - stream())); + auto g = gather_qmm( + cotan, + w, + scales, + biases, + std::nullopt, + rhs_indices, + !transpose_, + group_size_, + bits_, + sorted, + stream()); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(x, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + x.shape(), + stream())); + } } // gradient wrt to the indices is undefined @@ -3117,9 +3362,49 @@ std::vector GatherQMM::vjp( } // gradient wrt to w_q, scales or biases - else { + else if (arg == 1) { throw std::runtime_error( - "GatherQMM::vjp no gradient wrt the quantized matrix yet."); + "GatherQMM::vjp no gradient wrt the quantized weights."); + } else { + if (!dsb) { + auto shape = w.shape(); + shape.pop_back(); + shape.pop_back(); + dsb = unflatten( + gather_mm_grad( + x, + cotan, + lhs_indices, + rhs_indices, + sorted, + std::move(shape), + stream()), + -1, + {-1, group_size_}, + stream()); + } + if (arg == 3) { + vjps.push_back(sum(*dsb, -1, false, 
stream())); + } else { + vjps.push_back( + sum(multiply( + *dsb, + unflatten( + dequantize( + w, + ones_like(scales, stream()), + zeros_like(biases, stream()), + group_size_, + bits_, + stream()), + -1, + {-1, group_size_}, + stream()), + stream()), + -1, + false, + stream())); + } } } return vjps; @@ -3375,7 +3660,7 @@ std::vector Reduce::vjp( } else { - throw std::runtime_error("Reduce type VJP not yet implemented."); + return {zeros_like(in, stream())}; } } @@ -4891,6 +5176,8 @@ std::vector GatherMM::vjp( std::vector vjps; auto& cotan = cotangents[0]; + auto& a = primals[0]; + auto& b = primals[1]; auto& lhs_indices = primals[2]; auto& rhs_indices = primals[3]; @@ -4899,39 +5186,46 @@ std::vector GatherMM::vjp( int K = primals[0].shape(-1); bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == primals[0].size(); for (auto arg : argnums) { if (arg == 0) { - // M X N * (K X N).T -> M X K - auto base = zeros_like(primals[0], stream()); - auto bt = swapaxes(primals[1], -1, -2, stream()); - - auto base_shape = base.shape(); - base = reshape(base, {-1, M, K}, stream()); - - // g : (out_batch_shape) + (M, K) - auto g = - gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); - - vjps.push_back(reshape(gacc, base_shape, stream())); - + auto g = gather_mm( + cotan, + swapaxes(b, -1, -2, stream()), + std::nullopt, + rhs_indices, + sorted, + stream()); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(a, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + a.shape(), + stream())); + } } else if (arg == 1) { - // (M X K).T * M X N -> K X N - auto base = zeros_like(primals[1], stream()); - auto at = swapaxes(primals[0], -1, -2, stream()); - - auto base_shape = base.shape(); - base = reshape(base, {-1, K, N}, stream()); - - // g : (out_batch_shape) + (K, N) - auto g = - gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); - - vjps.push_back(reshape(gacc, base_shape, stream())); + auto shape = b.shape(); + shape.pop_back(); + shape.pop_back(); + vjps.push_back(swapaxes( + gather_mm_grad( + a, + cotan, + lhs_indices, + rhs_indices, + sorted, + std::move(shape), + stream()), + -1, + -2, + stream())); } else { throw std::invalid_argument( "[GatherMM] Cannot calculate VJP with respect to indices."); @@ -5061,8 +5355,13 @@ std::pair, std::vector> View::vmap( return {{view(inputs[0], dtype_, stream())}, axes}; } -void View::print(std::ostream& os) { - os << "View " << dtype_; +const char* View::name() const { + if (name_.empty()) { + std::ostringstream os; + os << "View " << dtype_; + name_ = os.str(); + } + return name_.c_str(); } bool View::is_equivalent(const Primitive& other) const { diff --git a/mlx/primitives.h b/mlx/primitives.h index 997931f30..d482a1bf9 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -26,9 +26,9 @@ const std::vector& argnums, \ const std::vector& outputs) override; -#define DEFINE_PRINT(PRIMITIVE) \ - void print(std::ostream& os) override { \ - os << #PRIMITIVE; \ +#define DEFINE_NAME(PRIMITIVE) \ + const char* name() const override { \ + return #PRIMITIVE; \ } #define DEFINE_DEFAULT_IS_EQUIVALENT() \ @@ -100,8 +100,8 @@ class Primitive { const std::vector& inputs, const std::vector& axes); - /** 
Print the primitive. */ - virtual void print(std::ostream& os) = 0; + /** Get the name of primitive. */ + virtual const char* name() const = 0; /** Equivalence check defaults to false unless overridden by the primitive */ virtual bool is_equivalent(const Primitive& other) const { @@ -160,7 +160,7 @@ class Abs : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Abs) + DEFINE_NAME(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -174,7 +174,7 @@ class Add : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Add) + DEFINE_NAME(Add) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -189,7 +189,7 @@ class AddMM : public UnaryPrimitive { DEFINE_GRADS() DEFINE_VMAP() - DEFINE_PRINT(AddMM) + DEFINE_NAME(AddMM) bool is_equivalent(const Primitive& other) const override; std::pair state() const { @@ -209,7 +209,7 @@ class Arange : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(Arange) + DEFINE_NAME(Arange) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::tuple state() const { @@ -231,7 +231,7 @@ class ArcCos : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcCos) + DEFINE_NAME(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -245,7 +245,7 @@ class ArcCosh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcCosh) + DEFINE_NAME(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -259,7 +259,7 @@ class ArcSin : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcSin) + DEFINE_NAME(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -273,7 +273,7 @@ class ArcSinh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcSinh) + DEFINE_NAME(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -287,7 +287,7 @@ class ArcTan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTan) + DEFINE_NAME(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -301,7 +301,7 @@ class ArcTan2 : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTan2) + DEFINE_NAME(ArcTan2) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -315,7 +315,7 @@ class ArcTanh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTanh) + DEFINE_NAME(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -330,7 +330,7 @@ class ArgPartition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgPartition) + DEFINE_NAME(ArgPartition) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; std::pair state() const { @@ -357,7 +357,7 @@ class ArgReduce : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgReduce) + DEFINE_NAME(ArgReduce) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::pair state() const { @@ -378,7 +378,8 @@ class ArgSort : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(ArgSort) + DEFINE_GRADS() + DEFINE_NAME(ArgSort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; int state() const { @@ -399,7 +400,7 @@ class AsType : public UnaryPrimitive { 
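This part of the diff swaps the streaming print(std::ostream&) hook for a const char* name() accessor and DEFINE_PRINT for DEFINE_NAME; call sites such as print_graph and export_to_dot earlier in the diff now stream the returned string themselves. A minimal standalone sketch of the pattern, with illustrative class names (only the View body mirrors the actual change, and a plain int stands in for Dtype):

#include <iostream>
#include <sstream>
#include <string>

struct Primitive {
  virtual const char* name() const = 0;
  virtual ~Primitive() = default;
};

struct Abs : Primitive {
  const char* name() const override {
    return "Abs"; // roughly what DEFINE_NAME(Abs) expands to
  }
};

struct View : Primitive {
  explicit View(int dtype) : dtype_(dtype) {}
  const char* name() const override {
    // Stateful primitives compose the name lazily and cache it in a mutable
    // string so the returned pointer stays valid after name() returns.
    if (name_.empty()) {
      std::ostringstream os;
      os << "View " << dtype_;
      name_ = os.str();
    }
    return name_.c_str();
  }

 private:
  int dtype_;
  mutable std::string name_;
};

int main() {
  Abs a;
  View v(7);
  // Callers stream the string instead of handing the primitive a stream.
  std::cout << a.name() << " " << v.name() << "\n"; // Abs View 7
}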
DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(AsType) + DEFINE_NAME(AsType) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; Dtype state() const { @@ -422,7 +423,7 @@ class AsStrided : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_GRADS() - DEFINE_PRINT(AsStrided) + DEFINE_NAME(AsStrided) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(shape_, strides_, offset_); @@ -448,8 +449,24 @@ class BitwiseBinary : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() + + const char* name() const override { + switch (op_) { + case BitwiseBinary::And: + return "BitwiseAnd"; + case BitwiseBinary::Or: + return "BitwiseOr"; + case BitwiseBinary::Xor: + return "BitwiseXor"; + case BitwiseBinary::LeftShift: + return "LeftShift"; + case BitwiseBinary::RightShift: + return "RightShift"; + } + return ""; + } + bool is_equivalent(const Primitive& other) const override; - void print(std::ostream& os) override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return op_; @@ -467,7 +484,7 @@ class BitwiseInvert : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(BitwiseInvert) + DEFINE_NAME(BitwiseInvert) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -486,7 +503,7 @@ class BlockMaskedMM : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(BlockMaskedMM) + DEFINE_NAME(BlockMaskedMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return block_size_; @@ -515,7 +532,7 @@ class GatherMM : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(GatherMM) + DEFINE_NAME(GatherMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(left_sorted_, right_sorted_); @@ -526,6 +543,16 @@ class GatherMM : public UnaryPrimitive { bool right_sorted_; }; +class SegmentedMM : public UnaryPrimitive { + public: + explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(SegmentedMM) +}; + class BroadcastAxes : public UnaryPrimitive { public: explicit BroadcastAxes(Stream stream, std::vector ignore_axes = {}) @@ -536,7 +563,7 @@ class BroadcastAxes : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(BroadcastAxes) + DEFINE_NAME(BroadcastAxes) bool is_equivalent(const Primitive& other) const override; static Shape output_shape( const std::vector& inputs, @@ -561,7 +588,7 @@ class Broadcast : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Broadcast) + DEFINE_NAME(Broadcast) static Shape output_shape(const std::vector& inputs); std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -584,7 +611,7 @@ class Ceil : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Ceil) + DEFINE_NAME(Ceil) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -614,8 +641,8 @@ class Compiled : public Primitive { DEFINE_VMAP() DEFINE_GRADS() + const char* name() const override; std::vector output_shapes(const std::vector& inputs) override; - void print(std::ostream& os) override; bool is_equivalent(const Primitive& other) const override; std::string lib_name() const { @@ 
-627,7 +654,9 @@ class Compiled : public Primitive { const std::vector outputs_; const std::vector tape_; const std::unordered_set constant_ids_; + const std::function is_constant_; + mutable std::string name_; std::string kernel_lib_; }; @@ -641,7 +670,7 @@ class Concatenate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Concatenate) + DEFINE_NAME(Concatenate) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -660,7 +689,7 @@ class Conjugate : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(Conjugate) + DEFINE_NAME(Conjugate) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -675,7 +704,7 @@ class Contiguous : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Contiguous) + DEFINE_NAME(Contiguous) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -689,13 +718,15 @@ class Convolution : public UnaryPrimitive { explicit Convolution( Stream stream, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& kernel_dilation, const std::vector& input_dilation, const int groups = 1, const bool flip = false) : UnaryPrimitive(stream), - padding_(padding), + padding_lo_(padding_lo), + padding_hi_(padding_hi), kernel_strides_(kernel_strides), kernel_dilation_(kernel_dilation), input_dilation_(input_dilation), @@ -711,12 +742,14 @@ class Convolution : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(Convolution) + DEFINE_VMAP() + DEFINE_NAME(Convolution) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - padding_, kernel_strides_, + padding_lo_, + padding_hi_, kernel_dilation_, input_dilation_, groups_, @@ -724,7 +757,8 @@ class Convolution : public UnaryPrimitive { } private: - std::vector padding_; + std::vector padding_lo_; + std::vector padding_hi_; std::vector kernel_strides_; std::vector kernel_dilation_; std::vector input_dilation_; @@ -741,7 +775,7 @@ class Copy : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Copy) + DEFINE_NAME(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() @@ -758,7 +792,7 @@ class Cos : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Cos) + DEFINE_NAME(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -772,7 +806,7 @@ class Cosh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Cosh) + DEFINE_NAME(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -806,7 +840,7 @@ class CustomTransforms : public Primitive { DEFINE_GRADS(); DEFINE_VMAP(); - DEFINE_PRINT(CustomTransforms); + DEFINE_NAME(CustomTransforms); private: void eval(const std::vector& inputs, std::vector& outputs); @@ -844,7 +878,7 @@ class Depends : public Primitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(Depends); + DEFINE_NAME(Depends); private: void eval(const std::vector& inputs, std::vector& outputs); @@ -859,7 +893,7 @@ class Divide : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Divide) + DEFINE_NAME(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -875,7 +909,7 @@ class DivMod : public Primitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DivMod) + 
DEFINE_NAME(DivMod) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override { return std::vector{inputs[0].shape(), inputs[0].shape()}; @@ -891,7 +925,7 @@ class Select : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Select) + DEFINE_NAME(Select) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -905,7 +939,7 @@ class Remainder : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Remainder) + DEFINE_NAME(Remainder) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -923,11 +957,11 @@ class Equal : public UnaryPrimitive { DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() - void print(std::ostream& os) override { + const char* name() const override { if (equal_nan_) { - os << "NaNEqual"; + return "NaNEqual"; } else { - os << "Equal"; + return "Equal"; } } auto state() const { @@ -947,7 +981,7 @@ class Erf : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Erf) + DEFINE_NAME(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -961,7 +995,7 @@ class ErfInv : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ErfInv) + DEFINE_NAME(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -975,7 +1009,7 @@ class Exp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Exp) + DEFINE_NAME(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -989,7 +1023,7 @@ class Expm1 : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Expm1) + DEFINE_NAME(Expm1) DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1003,7 +1037,7 @@ class ExpandDims : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ExpandDims) + DEFINE_NAME(ExpandDims) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -1032,7 +1066,7 @@ class FFT : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(FFT) + DEFINE_NAME(FFT) bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -1055,7 +1089,7 @@ class Flatten : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Flatten) + DEFINE_NAME(Flatten) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -1079,7 +1113,7 @@ class Floor : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Floor) + DEFINE_NAME(Floor) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1093,7 +1127,7 @@ class Full : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Full) + DEFINE_NAME(Full) DEFINE_DEFAULT_IS_EQUIVALENT() }; @@ -1109,7 +1143,7 @@ class Gather : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Gather) + DEFINE_NAME(Gather) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::pair, std::vector> state() const { @@ -1131,7 +1165,7 @@ class GatherAxis : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GatherAxis) + DEFINE_NAME(GatherAxis) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -1151,7 +1185,7 @@ class Greater : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Greater) + DEFINE_NAME(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1165,7 +1199,7 @@ class GreaterEqual : 
public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GreaterEqual) + DEFINE_NAME(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1180,7 +1214,7 @@ class Hadamard : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Hadamard) + DEFINE_NAME(Hadamard) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -1201,7 +1235,7 @@ class Imag : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Imag) + DEFINE_NAME(Imag) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1215,7 +1249,7 @@ class Less : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Less) + DEFINE_NAME(Less) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1229,7 +1263,7 @@ class LessEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LessEqual) + DEFINE_NAME(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1249,7 +1283,7 @@ class Load : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(Load) + DEFINE_NAME(Load) private: std::shared_ptr reader_; @@ -1276,18 +1310,16 @@ class Log : public UnaryPrimitive { return base_; }; - void print(std::ostream& os) override { + const char* name() const override { switch (base_) { case e: - os << "Log"; - break; + return "Log"; case two: - os << "Log2"; - break; + return "Log2"; case ten: - os << "Log10"; - break; + return "Log10"; } + return ""; } private: @@ -1303,7 +1335,7 @@ class Log1p : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Log1p) + DEFINE_NAME(Log1p) DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1316,7 +1348,7 @@ class LogicalNot : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalNot) + DEFINE_NAME(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1330,7 +1362,7 @@ class LogicalAnd : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalAnd) + DEFINE_NAME(LogicalAnd) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1344,7 +1376,7 @@ class LogicalOr : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalOr) + DEFINE_NAME(LogicalOr) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1358,7 +1390,7 @@ class LogAddExp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogAddExp) + DEFINE_NAME(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1372,7 +1404,7 @@ class LogSumExp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogSumExp) + DEFINE_NAME(LogSumExp) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override; }; @@ -1386,7 +1418,7 @@ class Matmul : public UnaryPrimitive { DEFINE_GRADS() DEFINE_VMAP() - DEFINE_PRINT(Matmul) + DEFINE_NAME(Matmul) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override; }; @@ -1400,7 +1432,7 @@ class Maximum : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Maximum) + DEFINE_NAME(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1414,7 +1446,7 @@ class Minimum : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Minimum) + DEFINE_NAME(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1428,7 +1460,7 @@ class Multiply : public UnaryPrimitive { DEFINE_VMAP() 
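Returning to the complex matmul change in mlx/ops.cpp earlier in this diff: it uses the three-multiplication trick the comment there calls Karatsuba's algorithm, forming m1 = ar*br, m2 = ai*bi and m3 = (ar+ai)*(br+bi) so that one of the four real matmuls of the textbook formula is traded for two cheap additions. A scalar sketch of the identity with arbitrary values (not the matmul itself):

#include <complex>
#include <cstdio>

int main() {
  double ar = 1.5, ai = -2.0, br = 0.5, bi = 3.0;

  double m1 = ar * br;               // real * real
  double m2 = ai * bi;               // imag * imag
  double m3 = (ar + ai) * (br + bi); // one product replaces two

  double c_real = m1 - m2;        // ar*br - ai*bi
  double c_imag = m3 - (m1 + m2); // ar*bi + ai*br

  std::complex<double> reference =
      std::complex<double>(ar, ai) * std::complex<double>(br, bi);
  std::printf("karatsuba: %g + %gi\n", c_real, c_imag);                       // 6.75 + 3.5i
  std::printf("reference: %g + %gi\n", reference.real(), reference.imag());  // 6.75 + 3.5i
}

For matrices the saving is a full real matmul per complex matmul, which is why the earlier four-matmul path in the same file was removed.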
DEFINE_GRADS() - DEFINE_PRINT(Multiply) + DEFINE_NAME(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1442,7 +1474,7 @@ class Negative : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Negative) + DEFINE_NAME(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1456,7 +1488,7 @@ class NotEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(NotEqual) + DEFINE_NAME(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1477,7 +1509,7 @@ class NumberOfElements : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(NumberOfElements) + DEFINE_NAME(NumberOfElements) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override { return {{}}; @@ -1511,7 +1543,7 @@ class Pad : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Pad) + DEFINE_NAME(Pad) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(axes_, low_pad_size_, high_pad_size_); @@ -1533,7 +1565,7 @@ class Partition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Partition) + DEFINE_NAME(Partition) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -1554,7 +1586,7 @@ class Power : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Power) + DEFINE_NAME(Power) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1576,7 +1608,7 @@ class QuantizedMatmul : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(QuantizedMatmul) + DEFINE_NAME(QuantizedMatmul) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -1610,7 +1642,7 @@ class GatherQMM : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GatherQMM) + DEFINE_NAME(GatherQMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( @@ -1634,7 +1666,7 @@ class RandomBits : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(RandomBits) + DEFINE_NAME(RandomBits) bool is_equivalent(const Primitive& other) const override; std::pair, int> state() const { return {shape_, width_}; @@ -1654,7 +1686,7 @@ class Real : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Real) + DEFINE_NAME(Real) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1669,7 +1701,7 @@ class Reshape : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Reshape) + DEFINE_NAME(Reshape) bool is_equivalent(const Primitive& other) const override; std::vector state() const { return shape_; @@ -1704,28 +1736,24 @@ class Reduce : public UnaryPrimitive { std::vector output_shapes(const std::vector& inputs) override; - void print(std::ostream& os) override { + const char* name() const override { switch (reduce_type_) { case And: - os << "And"; - break; + return "And"; case Or: - os << "Or"; - break; + return "Or"; case Sum: - os << "Sum"; - break; + return "Sum"; case Prod: - os << "Prod"; - break; + return "Prod"; case Min: - os << "Min"; - break; + return "Min"; case Max: - os << "Max"; - break; + return "Max"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; std::pair> state() const { return {reduce_type_, axes_}; 
@@ -1745,7 +1773,7 @@ class Round : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Round) + DEFINE_NAME(Round) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1772,26 +1800,22 @@ class Scan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS(); - void print(std::ostream& os) override { - os << "Cum"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << "Sum"; - break; + return "CumSum"; case Prod: - os << "Prod"; - break; + return "CumProd"; case Min: - os << "Min"; - break; + return "CumMin"; case Max: - os << "Max"; - break; + return "CumMax"; case LogAddExp: - os << "Logaddexp"; - break; + return "CumLogAddExp"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_); @@ -1820,25 +1844,22 @@ class Scatter : public UnaryPrimitive { DEFINE_VMAP(); DEFINE_GRADS(); - void print(std::ostream& os) override { - os << "Scatter"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << " Sum"; - break; + return "ScatterSum"; case Prod: - os << " Prod"; - break; + return "ScatterProd"; case Min: - os << " Min"; - break; + return "ScatterMin"; case Max: - os << " Max"; - break; + return "ScatterMax"; case None: - break; + return "Scatter"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; std::pair> state() const { return {reduce_type_, axes_}; @@ -1862,15 +1883,14 @@ class ScatterAxis : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - void print(std::ostream& os) override { - os << "ScatterAxis"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << " Sum"; - break; + return "ScatterAxisSum"; case None: - break; + return "ScatterAxis"; } + return ""; } bool is_equivalent(const Primitive& other) const override; @@ -1893,7 +1913,7 @@ class Sigmoid : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sigmoid) + DEFINE_NAME(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1907,7 +1927,7 @@ class Sign : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sign) + DEFINE_NAME(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1921,7 +1941,7 @@ class Sin : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sin) + DEFINE_NAME(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1935,7 +1955,7 @@ class Sinh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sinh) + DEFINE_NAME(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1957,7 +1977,7 @@ class Slice : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Slice) + DEFINE_NAME(Slice) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(start_indices_, end_indices_, strides_); @@ -1986,7 +2006,7 @@ class SliceUpdate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(SliceUpdate) + DEFINE_NAME(SliceUpdate) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -2011,7 +2031,7 @@ class DynamicSlice : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DynamicSlice) + DEFINE_NAME(DynamicSlice) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -2033,7 +2053,7 @@ class DynamicSliceUpdate : public 
UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DynamicSliceUpdate) + DEFINE_NAME(DynamicSliceUpdate) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -2054,7 +2074,7 @@ class Softmax : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Softmax) + DEFINE_NAME(Softmax) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -2076,7 +2096,7 @@ class Sort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sort) + DEFINE_NAME(Sort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -2099,7 +2119,7 @@ class Split : public Primitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Split) + DEFINE_NAME(Split) bool is_equivalent(const Primitive& other) const override; std::pair, int> state() const { return {indices_, axis_}; @@ -2121,7 +2141,7 @@ class Square : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Square) + DEFINE_NAME(Square) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2142,11 +2162,11 @@ class Sqrt : public UnaryPrimitive { return recip_; } - void print(std::ostream& os) override { + const char* name() const override { if (recip_) { - os << "Rsqrt"; + return "Rsqrt"; } else { - os << "Sqrt"; + return "Sqrt"; } } @@ -2162,7 +2182,7 @@ class StopGradient : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(StopGradient) + DEFINE_NAME(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() @@ -2179,7 +2199,7 @@ class Subtract : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Subtract) + DEFINE_NAME(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2194,7 +2214,7 @@ class Squeeze : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Squeeze) + DEFINE_NAME(Squeeze) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -2218,7 +2238,7 @@ class Tan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Tan) + DEFINE_NAME(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2232,7 +2252,7 @@ class Tanh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Tanh) + DEFINE_NAME(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2247,7 +2267,7 @@ class Unflatten : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Unflatten) + DEFINE_NAME(Unflatten) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -2271,7 +2291,7 @@ class View : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - void print(std::ostream& os) override; + const char* name() const override; bool is_equivalent(const Primitive& other) const override; auto state() const { return dtype_; @@ -2279,6 +2299,7 @@ class View : public UnaryPrimitive { private: Dtype dtype_; + mutable std::string name_; }; class Transpose : public UnaryPrimitive { @@ -2291,7 +2312,7 @@ class Transpose : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Transpose) + DEFINE_NAME(Transpose) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::vector state() const { @@ -2314,7 +2335,7 @@ class QRF : public 
Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(QRF) + DEFINE_NAME(QRF) }; /* SVD primitive. */ @@ -2329,7 +2350,7 @@ class SVD : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(SVD) + DEFINE_NAME(SVD) auto state() const { return compute_uv_; } @@ -2348,7 +2369,7 @@ class Inverse : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& output) override; DEFINE_VMAP() - DEFINE_PRINT(Inverse) + DEFINE_NAME(Inverse) auto state() const { return std::make_pair(tri_, upper_); } @@ -2370,12 +2391,35 @@ class Cholesky : public UnaryPrimitive { } DEFINE_VMAP() - DEFINE_PRINT(Cholesky) + DEFINE_NAME(Cholesky) private: bool upper_; }; +class Eig : public Primitive { + public: + explicit Eig(Stream stream, bool compute_eigenvectors) + : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_NAME(Eig) + + std::vector output_shapes(const std::vector& inputs) override; + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return compute_eigenvectors_; + } + + private: + bool compute_eigenvectors_; +}; + class Eigh : public Primitive { public: explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors) @@ -2388,7 +2432,7 @@ class Eigh : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(Eigh) + DEFINE_NAME(Eigh) std::vector output_shapes(const std::vector& inputs) override; @@ -2411,7 +2455,7 @@ class LUF : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(LUF) + DEFINE_NAME(LUF) }; } // namespace mlx::core diff --git a/mlx/random.cpp b/mlx/random.cpp index d6ce5bb0e..def3169cb 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -92,29 +92,6 @@ T below_one() { return f; } -// Get the next representable value above -1.0 for half precision -// floating point types (fp16, bf16) -template -T above_minus_one() { - T f = T(-1.0); - uint16_t* m = (uint16_t*)&f; - *m -= 1; - return f; -} - -// Get the next representable value above -1.0 for half precision -// use std::nextafter as default case. 
-array above_minus_one_with_default(Dtype dtype) { - switch (dtype) { - case float16: - return array(above_minus_one(), dtype); - case bfloat16: - return array(above_minus_one(), dtype); - default: - return array(std::nextafter(-1.0f, 0.0f), dtype); - } -} - array uniform( const array& low, const array& high, @@ -139,31 +116,27 @@ array uniform( << " from broadcasted shape " << out_shape << "."; throw std::invalid_argument(msg.str()); } - // Get random values between [0, nextafter(maxval, 0.0f)] since samples must + + // Get random values between [0, nextafter(1.0, 0.0)] since samples must // be in [low, high) - auto get_limits = [&dtype]() { + auto get_upper = [&dtype]() { switch (dtype) { case float32: - return std::make_pair( - array(std::nextafter(1.0f, 0.0f), float32), - array(std::numeric_limits::max(), float32)); + return array(std::nextafter(1.0f, 0.0f), float32); case float16: - return std::make_pair( - array(below_one(), float16), - array(std::numeric_limits::max(), float32)); + return array(below_one(), float32); case bfloat16: - return std::make_pair( - array(below_one(), bfloat16), - array(std::numeric_limits::max(), float32)); + return array(below_one(), float32); default: throw std::runtime_error("[uniform] Unsupported type."); } }; - auto [upper, maxval] = get_limits(); - auto out = bits(shape, size_of(dtype), key, stream); - out = astype(divide(out, maxval, stream), dtype, stream); - out = minimum(out, upper, stream); + auto upper = get_upper(); + auto maxval = array(std::numeric_limits::max(), float32); + auto out = bits(shape, size_of(float32), key, stream); + out = divide(out, maxval, stream); + out = astype(minimum(out, upper, stream), dtype, stream); return add(multiply(range, out, stream), lo, stream); } @@ -176,24 +149,56 @@ array uniform( array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s)); } +inline array complex_normal( + Shape shape, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s) { + auto stream = to_stream(s); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + shape.push_back(2); + auto samples = + erfinv(uniform(low, high, shape, float32, key, stream), stream); + samples = squeeze(view(samples, complex64, stream), -1, stream); + if (scale.has_value()) { + samples = multiply(*scale, samples, stream); + } + if (loc.has_value()) { + samples = add(*loc, samples, stream); + } + return samples; +} + array normal( const Shape& shape, Dtype dtype, - const float loc /* = 0.0 */, - const float scale /* = 1.0 */, - const std::optional& key /*= nullopt */, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, StreamOrDevice s /* = {} */) { - auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); - auto high = array(1.0f, dtype); - auto samples = uniform(low, high, shape, dtype, key, stream); - samples = - multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); - if (scale != 1.0) { - samples = multiply(array(scale, dtype), samples, stream); + if (dtype == complex64) { + return complex_normal(shape, loc, scale, key, s); + } else if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[normal] Can only generate uniform numbers with " + "floating point type."); } - if (loc != 0.0) { - samples = add(array(loc, dtype), samples, stream); + + auto stream = to_stream(s); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto 
samples = uniform(low, high, shape, float32, key, stream); + auto applied_scale = array(std::sqrt(2.0), dtype); + if (scale.has_value()) { + applied_scale = + multiply(applied_scale, astype(*scale, dtype, stream), stream); + } + samples = astype(erfinv(samples, stream), dtype, stream); + samples = multiply(applied_scale, samples, stream); + if (loc.has_value()) { + samples = add(astype(*loc, dtype, stream), samples, stream); } return samples; } @@ -223,7 +228,7 @@ array multivariate_normal( auto n = mean.shape(-1); - // Check shapes comatibility of mean and cov + // Check shapes compatibility of mean and cov if (cov.shape(-1) != cov.shape(-2)) { throw std::invalid_argument( "[multivariate_normal] last two dimensions of cov must be equal."); @@ -402,7 +407,7 @@ array categorical( if (broadcast_shapes(shape, reduced_shape) != shape) { std::ostringstream msg; msg << "[categorical] Requested shape " << shape - << " is not broadcast compatable with reduced logits shape" + << " is not broadcast compatible with reduced logits shape" << reduced_shape << "."; throw std::invalid_argument(msg.str()); } @@ -442,16 +447,23 @@ array laplace( const float scale /* = 1.0 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { + if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[laplace] Can only generate uniform numbers with real" + "floating point type."); + } + auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); - auto high = array(1.0f, dtype); - auto samples = uniform(low, high, shape, dtype, key, stream); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto samples = uniform(low, high, shape, float32, key, stream); // Use inverse CDF to generate Laplacian noise samples = multiply( sign(samples, stream), log1p( multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream), stream); + samples = astype(samples, dtype, stream); if (scale != 1.0) { samples = multiply(array(scale, dtype), samples, stream); diff --git a/mlx/random.h b/mlx/random.h index b2c821736..0dfdab7a1 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -94,12 +94,24 @@ inline array uniform( /** Generate samples from the standard normal distribution. */ array normal( + const Shape& shape, + Dtype dtype, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s = {}); +inline array normal( const Shape& shape, Dtype dtype, const float loc, const float scale, const std::optional& key = std::nullopt, - StreamOrDevice s = {}); + StreamOrDevice s = {}) { + auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype)); + auto scale_ = + scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype)); + return normal(shape, dtype, loc_, scale_, key, s); +} inline array normal( const Shape& shape, const float loc, @@ -113,13 +125,13 @@ inline array normal( const Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, dtype, 0.0, 1.0, key, s); + return normal(shape, dtype, std::nullopt, std::nullopt, key, s); } inline array normal( const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, float32, 0.0, 1.0, key, s); + return normal(shape, float32, std::nullopt, std::nullopt, key, s); } /** Generate samples from a multivariate normal distribution. 
**/ diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index 7bd128c10..b19f6434a 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -1,12 +1,13 @@ // Copyright © 2023 Apple Inc. #include "mlx/scheduler.h" -#include "mlx/backend/metal/metal.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" namespace mlx::core { Stream default_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[default_stream] Cannot get gpu stream without gpu backend."); } @@ -14,7 +15,7 @@ Stream default_stream(Device d) { } void set_default_stream(Stream s) { - if (!metal::is_available() && s.device == Device::gpu) { + if (!gpu::is_available() && s.device == Device::gpu) { throw std::invalid_argument( "[set_default_stream] Cannot set gpu stream without gpu backend."); } @@ -26,7 +27,7 @@ Stream get_stream(int index) { } Stream new_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[new_stream] Cannot make gpu stream without gpu backend."); } @@ -44,7 +45,7 @@ void synchronize(Stream s) { scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); }); f.wait(); } else { - metal::synchronize(s); + gpu::synchronize(s); } } diff --git a/mlx/scheduler.h b/mlx/scheduler.h index b2c6b842b..877fdd5f6 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -8,8 +8,7 @@ #include #include -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/device.h" #include "mlx/stream.h" @@ -67,7 +66,7 @@ struct StreamThread { class Scheduler { public: Scheduler() : n_active_tasks_(0) { - if (metal::is_available()) { + if (is_available(Device::gpu)) { default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); } default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); @@ -83,7 +82,7 @@ class Scheduler { streams_.emplace_back(streams_.size(), d); if (d == Device::gpu) { threads_.push_back(nullptr); - metal::new_stream(streams_.back()); + gpu::new_stream(streams_.back()); } else { threads_.push_back(new StreamThread{}); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index b305257f0..d9e227ea3 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -10,7 +10,7 @@ #include #include "mlx/backend/cpu/eval.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/fence.h" #include "mlx/memory.h" #include "mlx/ops.h" @@ -33,7 +33,7 @@ class Synchronizer : public Primitive { void eval_cpu(const std::vector&, std::vector&) override {} void eval_gpu(const std::vector&, std::vector&) override {} - DEFINE_PRINT(Synchronize); + DEFINE_NAME(Synchronize); }; // Initialize the static tracing members from transforms_impl.h @@ -42,7 +42,10 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. 
-std::vector> detail::InTracing::trace_stack{}; +std::vector>& detail::InTracing::trace_stack() { + static std::vector> trace_stack_; + return trace_stack_; +} int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; @@ -215,7 +218,7 @@ array eval_impl(std::vector outputs, bool async) { } if (arr.primitive().device() == Device::gpu) { - metal::eval(arr); + gpu::eval(arr); } else { cpu::eval(arr); } @@ -226,7 +229,7 @@ array eval_impl(std::vector outputs, bool async) { // Commit any open streams for (auto& [_, e] : events) { if (e.stream().device == Device::gpu) { - metal::finalize(e.stream()); + gpu::finalize(e.stream()); } } scheduler::wait_for_one(); @@ -264,7 +267,7 @@ array eval_impl(std::vector outputs, bool async) { auto s = e.stream(); e.signal(s); if (s.device == Device::gpu) { - metal::finalize(s); + gpu::finalize(s); } } diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 7f62c406b..46851fa3d 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -22,19 +22,19 @@ std::vector vmap_replace( struct InTracing { explicit InTracing(bool dynamic = false, bool grad = false) { grad_counter += grad; - trace_stack.push_back({dynamic, grad}); + trace_stack().push_back({dynamic, grad}); } ~InTracing() { - grad_counter -= trace_stack.back().second; - trace_stack.pop_back(); + grad_counter -= trace_stack().back().second; + trace_stack().pop_back(); } static bool in_tracing() { - return !trace_stack.empty(); + return !trace_stack().empty(); } static bool in_dynamic_tracing() { // compile is always and only the outer-most transform - return in_tracing() && trace_stack.front().first; + return in_tracing() && trace_stack().front().first; } static bool in_grad_tracing() { @@ -43,7 +43,7 @@ struct InTracing { private: static int grad_counter; - static std::vector> trace_stack; + static std::vector>& trace_stack(); }; struct RetainGraph { diff --git a/mlx/types/limits.h b/mlx/types/limits.h index 7e0de15bc..5f2b1e9e0 100644 --- a/mlx/types/limits.h +++ b/mlx/types/limits.h @@ -33,6 +33,9 @@ struct numeric_limits { static constexpr float16_t max() { return bits_to_half(0x7BFF); } + static constexpr float16_t epsilon() { + return bits_to_half(0x1400); + } static constexpr float16_t infinity() { return bits_to_half(0x7C00); } @@ -56,6 +59,9 @@ struct numeric_limits { static constexpr bfloat16_t max() { return bits_to_bfloat(0x7F7F); } + static constexpr bfloat16_t epsilon() { + return bits_to_bfloat(0x3C00); + } static constexpr bfloat16_t infinity() { return bits_to_bfloat(0x7F80); } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 188584174..e53a7a97f 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) { os << val; } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - os << val; + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } } PrintFormatter& get_global_formatter() { @@ -248,7 +253,9 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { std::ostream& operator<<(std::ostream& os, array a) { a.eval(); - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array(os, a)); + dispatch_all_types(a.dtype(), [&](auto type_tag) { + print_array(os, a); + }); return os; } @@ -283,9 +290,10 @@ int get_var(const char* name, int default_value) { } // namespace env template -void set_finfo_limits(double& min, double& max) { +void set_finfo_limits(double& 
min, double& max, double& eps) { min = numeric_limits::lowest(); max = numeric_limits::max(); + eps = numeric_limits::epsilon(); } finfo::finfo(Dtype dtype) : dtype(dtype) { @@ -295,16 +303,16 @@ finfo::finfo(Dtype dtype) : dtype(dtype) { throw std::invalid_argument(msg.str()); } if (dtype == float32) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == bfloat16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float64) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == complex64) { this->dtype = float32; - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } } @@ -315,8 +323,9 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) { } iinfo::iinfo(Dtype dtype) : dtype(dtype) { - MLX_SWITCH_INT_TYPES_CHECKED( - dtype, "[iinfo]", CTYPE, set_iinfo_limits(min, max)); + dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) { + set_iinfo_limits(min, max); + }); } } // namespace mlx::core diff --git a/mlx/utils.h b/mlx/utils.h index 19241e4c6..f16bf0468 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -65,6 +65,7 @@ struct finfo { Dtype dtype; double min; double max; + double eps; }; /** Holds information about integral types. */ @@ -148,6 +149,11 @@ inline bool metal_fast_synch() { return metal_fast_synch; } +inline bool enable_tf32() { + static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1); + return enable_tf32_; +} + } // namespace env } // namespace mlx::core diff --git a/mlx/version.h b/mlx/version.h index fe47d96cc..c01135177 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -3,8 +3,8 @@ #pragma once #define MLX_VERSION_MAJOR 0 -#define MLX_VERSION_MINOR 25 -#define MLX_VERSION_PATCH 0 +#define MLX_VERSION_MINOR 26 +#define MLX_VERSION_PATCH 3 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) diff --git a/pyproject.toml b/pyproject.toml index ad0d2e328..6fcd5d16c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "setuptools>=42", + "setuptools>=80", "nanobind==2.4.0", "cmake>=3.25", ] diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 9c946005b..404ecc349 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command): # Repeat the stdout and stderr to the local machine to_read = [p.stdout.fileno(), p.stderr.fileno()] - to_write = [p.stdin.fileno()] + to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()] pidfile = "" stdin_buffer = b"" + stdout_buffer = b"" + stderr_buffer = b"" while p.poll() is None: try: stdin_buffer += input_queue.get_nowait() @@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command): pass rlist, wlist, _ = select(to_read, to_write, [], 1.0) for fd in rlist: - is_stdout = fd == p.stdout.fileno() - outfile = sys.stdout if is_stdout else sys.stderr msg = os.read(fd, 8192).decode(errors="ignore") # Fetch the PID file first if we haven't already @@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command): pidfile, *msg = msg.split("\n", maxsplit=1) msg = msg[0] if msg else "" - outfile.write(msg) - outfile.flush() + is_stdout = fd == p.stdout.fileno() + if is_stdout: + stdout_buffer += msg.encode() + else: + stderr_buffer += msg.encode() for fd in wlist: - if len(stdin_buffer) > 0: + if fd == p.stdin.fileno() and 
len(stdin_buffer) > 0: n = os.write(fd, stdin_buffer) stdin_buffer = stdin_buffer[n:] + elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0: + n = os.write(fd, stdout_buffer) + stdout_buffer = stdout_buffer[n:] + elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0: + n = os.write(fd, stderr_buffer) + stderr_buffer = stderr_buffer[n:] if stop: p.terminate() break diff --git a/python/mlx/extension.py b/python/mlx/extension.py index 8c0d60655..c426d5953 100644 --- a/python/mlx/extension.py +++ b/python/mlx/extension.py @@ -53,11 +53,7 @@ class CMakeBuild(build_ext): # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level # across all generators. if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. - build_args += [f"-j{self.parallel}"] + build_args += [f"-j{os.cpu_count()}"] build_temp = Path(self.build_temp) / ext.name if not build_temp.exists(): diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 8eafd75d3..21994c0e6 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -546,7 +546,7 @@ class GELU(Module): See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the functional equivalents and information regarding error bounds. - + Args: approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any. @@ -554,20 +554,19 @@ class GELU(Module): def __init__(self, approx="none"): super().__init__() - - if approx == "none": - self._act = gelu - elif approx == "precise" or approx == "tanh": - self._act = gelu_approx - elif approx == "fast": - self._act = gelu_fast_approx - else: + self._approx = approx + allowed = ["none", "precise", "tanh", "fast"] + if approx not in allowed: raise ValueError( - f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given" + f"The approximation should be in {allowed} but '{approx}' was given" ) def __call__(self, x): - return self._act(x) + if self._approx == "none": + return gelu(x) + elif self._approx in ["precise", "tanh"]: + return gelu_approx(x) + return gelu_fast_approx(x) @_make_activation_module(tanh) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index b35c58478..e99943834 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -114,6 +114,12 @@ class Module(dict): super(Module, self).__setattr__(key, val) self.pop(key, None) + def __delattr__(self, name): + if (val := self.get(name, None)) is not None: + del self[name] + else: + super().__delattr__(name) + def load_weights( self, file_or_weights: Union[str, List[Tuple[str, mx.array]]], @@ -174,11 +180,15 @@ class Module(dict): new_weights = dict(weights) curr_weights = dict(tree_flatten(self.parameters())) if extras := (new_weights.keys() - curr_weights.keys()): - extras = " ".join(extras) - raise ValueError(f"Received parameters not in model: {extras}.") + num_extra = len(extras) + extras = ",\n".join(sorted(extras)) + raise ValueError( + f"Received {num_extra} parameters not in model: \n{extras}." 
+ ) if missing := (curr_weights.keys() - new_weights.keys()): - missing = " ".join(missing) - raise ValueError(f"Missing parameters: {missing}.") + num_missing = len(missing) + missing = ",\n".join(sorted(missing)) + raise ValueError(f"Missing {num_missing} parameters: \n{missing}.") for k, v in curr_weights.items(): v_new = new_weights[k] if not isinstance(v_new, mx.array): @@ -193,7 +203,7 @@ class Module(dict): ) if len(weights) != 0: - self.update(tree_unflatten(weights)) + self.update(tree_unflatten(weights), strict=False) return self def save_weights(self, file: str): @@ -291,7 +301,7 @@ class Module(dict): return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) - def update(self, parameters: dict) -> Module: + def update(self, parameters: dict, strict: bool = True) -> Module: """Replace the parameters of this Module with the provided ones in the dict of dicts and lists. @@ -305,7 +315,9 @@ class Module(dict): Args: parameters (dict): A complete or partial dictionary of the modules - parameters. + parameters. + strict (bool): If ``True`` checks that ``parameters`` is a + subset of the module's parameters. Default: ``True``. Returns: The module instance after updating the parameters. """ @@ -317,21 +329,29 @@ class Module(dict): current_value = dst[k] new_value = parameters[k] if isinstance(current_value, mx.array): + if strict and not isinstance(new_value, mx.array): + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) dst[k] = new_value - elif isinstance(current_value, Module): - current_value.update(new_value) - elif isinstance(current_value, (dict, list)): + else: apply(current_value, new_value) + elif strict: + raise ValueError(f'Module does not have parameter named "{k}".') elif isinstance(parameters, list): for i in range(len(parameters)): current_value = dst[i] new_value = parameters[i] if isinstance(current_value, mx.array): + if strict and not isinstance(new_value, mx.array): + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) dst[i] = new_value - elif isinstance(current_value, Module): - current_value.update(new_value) - elif isinstance(current_value, (dict, list)): + else: apply(current_value, new_value) + elif strict: + raise ValueError(f"Received invalid type: {type(parameters).__name__}.") apply(self, parameters) return self @@ -359,7 +379,7 @@ class Module(dict): self.update(self.filter_and_map(filter_fn, map_fn)) return self - def update_modules(self, modules: dict) -> Module: + def update_modules(self, modules: dict, strict: bool = True) -> Module: """Replace the child modules of this :class:`Module` instance with the provided ones in the dict of dicts and lists. @@ -368,12 +388,14 @@ class Module(dict): programmatically swapping layers. The passed in parameters dictionary need not be a full dictionary - similar to :meth:`parameters`. Only the provided locations will be + similar to :meth:`modules`. Only the provided locations will be updated. Args: - modules (dict): A complete or partial dictionary of the modules + modules (dict): A complete or partial dictionary of the module's submodules. + strict (bool): If ``True`` checks that ``modules`` is a + subset of the child modules of this instance. Default: ``True``. Returns: The module instance after updating the submodules. 
""" @@ -388,14 +410,28 @@ class Module(dict): dst[k] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) + elif strict and new_value != {}: + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) + elif strict: + raise ValueError( + f'Module does not have sub-module named "{k}".' + ) elif isinstance(modules, list): - for i in range(len(dst)): + for i in range(len(modules)): current_value = dst[i] new_value = modules[i] if self.is_module(current_value) and self.is_module(new_value): dst[i] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) + elif strict and new_value != {}: + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) + elif strict: + raise ValueError(f"Received invalid type: {type(modules).__name__}.") apply(self, modules) return self diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py index edacab061..a11c4cb40 100644 --- a/python/mlx/nn/layers/convolution_transpose.py +++ b/python/mlx/nn/layers/convolution_transpose.py @@ -25,6 +25,8 @@ class ConvTranspose1d(Module): padding (int, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int, optional): The dilation of the convolution. + output_padding(int, optional): Additional size added to one side of the + output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -37,6 +39,7 @@ class ConvTranspose1d(Module): stride: int = 1, padding: int = 0, dilation: int = 1, + output_padding: int = 0, bias: bool = True, ): super().__init__() @@ -53,18 +56,25 @@ class ConvTranspose1d(Module): self.padding = padding self.dilation = dilation self.stride = stride + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose1d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias @@ -90,6 +100,8 @@ class ConvTranspose2d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. 
Default: ``True`` """ @@ -102,13 +114,14 @@ class ConvTranspose2d(Module): stride: Union[int, tuple] = 1, padding: Union[int, tuple] = 0, dilation: Union[int, tuple] = 1, + output_padding: Union[int, tuple] = 0, bias: bool = True, ): super().__init__() - kernel_size, stride, padding = map( + kernel_size, stride, padding, output_padding = map( lambda x: (x, x) if isinstance(x, int) else x, - (kernel_size, stride, padding), + (kernel_size, stride, padding, output_padding), ) scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1])) self.weight = mx.random.uniform( @@ -122,18 +135,25 @@ class ConvTranspose2d(Module): self.padding = padding self.stride = stride self.dilation = dilation + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose2d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias @@ -160,6 +180,8 @@ class ConvTranspose3d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -172,13 +194,14 @@ class ConvTranspose3d(Module): stride: Union[int, tuple] = 1, padding: Union[int, tuple] = 0, dilation: Union[int, tuple] = 1, + output_padding: Union[int, tuple] = 0, bias: bool = True, ): super().__init__() - kernel_size, stride, padding = map( + kernel_size, stride, padding, output_padding = map( lambda x: (x, x, x) if isinstance(x, int) else x, - (kernel_size, stride, padding), + (kernel_size, stride, padding, output_padding), ) scale = math.sqrt( 1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) @@ -194,18 +217,25 @@ class ConvTranspose3d(Module): self.padding = padding self.stride = stride self.dilation = dilation + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose3d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 823a0084f..2d6dc0882 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -193,12 +193,6 @@ class QuantizedLinear(Module): # Freeze this model's parameters self.freeze() - def unfreeze(self, *args, **kwargs): - """Wrap unfreeze so that we unfreeze any layers we might contain but - our parameters will remain frozen.""" - super().unfreeze(*args, **kwargs) - self.freeze(recurse=False) - def _extra_repr(self): out_dims, in_dims = self.weight.shape in_dims *= 32 // self.bits diff --git a/python/mlx/nn/layers/upsample.py 
b/python/mlx/nn/layers/upsample.py index 1f2ffd3da..e6bd282af 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims): def _nearest_indices(N, scale, dim, ndims): - return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32) + M = int(scale * N) + indices = mx.arange(M, dtype=mx.float32) + if M > N: + indices = (indices + 0.5) * (N / M) - 0.5 + indices = indices.round() + else: + indices = indices * (N / M) + shape = [1] * ndims + shape[dim] = -1 + return indices.astype(mx.uint32).reshape(shape) def _linear_indices(N, scale, align_corners, dim, ndims): diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py index 7931c74fa..9f78aa912 100644 --- a/python/mlx/optimizers/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -526,8 +526,10 @@ class Adam(Optimizer): state["v"] = v if bias_correction: - numerator = lr / (1 - b1**step) * m - denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps + c1 = (lr / (1 - b1**step)).astype(gradient.dtype) + c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype) + numerator = c1 * m + denominator = mx.sqrt(v) * c2 + eps return parameter - numerator / denominator else: return parameter - lr * m / (mx.sqrt(v) + eps) diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh new file mode 100644 index 000000000..ec0a89930 --- /dev/null +++ b/python/scripts/repair_cuda.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +auditwheel repair dist/* \ + --plat manylinux_2_39_x86_64 \ + --exclude libcublas* \ + --exclude libnvrtc* \ + -w wheel_tmp + + +mkdir wheelhouse +cd wheel_tmp +repaired_wheel=$(find . -name "*.whl" -print -quit) +unzip -q "${repaired_wheel}" +rm "${repaired_wheel}" +mlx_so="mlx/lib/libmlx.so" +rpath=$(patchelf --print-rpath "${mlx_so}") +base="\$ORIGIN/../../nvidia" +rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib +patchelf --force-rpath --set-rpath "$rpath" "$mlx_so" +python ../python/scripts/repair_record.py ${mlx_so} + +# Re-zip the repaired wheel +zip -r -q "../wheelhouse/${repaired_wheel}" . diff --git a/python/scripts/repair_linux.sh b/python/scripts/repair_linux.sh new file mode 100644 index 000000000..82cf49060 --- /dev/null +++ b/python/scripts/repair_linux.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +auditwheel repair dist/* \ + --plat manylinux_2_35_x86_64 \ + --exclude libmlx* \ + -w wheel_tmp + +mkdir wheelhouse +cd wheel_tmp +repaired_wheel=$(find . -name "*.whl" -print -quit) +unzip -q "${repaired_wheel}" +rm "${repaired_wheel}" +core_so=$(find mlx -name "core*.so" -print -quit) +rpath="\$ORIGIN/lib" +patchelf --force-rpath --set-rpath "$rpath" "$core_so" +python ../python/scripts/repair_record.py ${core_so} + +# Re-zip the repaired wheel +zip -r -q "../wheelhouse/${repaired_wheel}" . 
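The two repair scripts above patch a shared library inside an already-built wheel, which invalidates that file's entry in the wheel's RECORD manifest, where each row stores ``path,sha256=<urlsafe-base64 digest>,<size>``. The repair_record.py helper added next updates that entry in place; a minimal standalone sketch of the same computation (reusing the ``mlx/lib/libmlx.so`` path from repair_cuda.sh, and assuming the repaired wheel has already been unzipped into the working directory) could look like:

.. code-block:: python

    import base64
    import hashlib

    def record_entry(path: str) -> str:
        # A RECORD row is "<path>,sha256=<urlsafe-b64 digest without padding>,<size in bytes>".
        with open(path, "rb") as f:
            data = f.read()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
        return f"{path},sha256={digest.decode('ascii')},{len(data)}"

    print(record_entry("mlx/lib/libmlx.so"))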
diff --git a/python/scripts/repair_record.py b/python/scripts/repair_record.py new file mode 100644 index 000000000..1738fd5ad --- /dev/null +++ b/python/scripts/repair_record.py @@ -0,0 +1,33 @@ +import base64 +import glob +import hashlib +import sys + +filename = sys.argv[1] + + +# Compute the new hash and size +def urlsafe_b64encode(data: bytes) -> bytes: + return base64.urlsafe_b64encode(data).rstrip(b"=") + + +hasher = hashlib.sha256() +with open(filename, "rb") as f: + data = f.read() + hasher.update(data) +hash_str = urlsafe_b64encode(hasher.digest()).decode("ascii") +size = len(data) + +# Update the record file +record_file = glob.glob("*/RECORD")[0] +with open(record_file, "r") as f: + lines = [l.split(",") for l in f.readlines()] + +for l in lines: + if filename == l[0]: + l[1] = hash_str + l[2] = f"{size}\n" + +with open(record_file, "w") as f: + for l in lines: + f.write(",".join(l)) diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 7ea302cf9..29beca859 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -54,5 +54,9 @@ target_link_libraries(core PRIVATE mlx) target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) if(BUILD_SHARED_LIBS) - target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) + if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) + else() + target_link_options(core PRIVATE -Wl,-rpath,\$ORIGIN/lib) + endif() endif() diff --git a/python/src/array.cpp b/python/src/array.cpp index 467bd0fa5..25889d775 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -17,10 +17,7 @@ #include "python/src/indexing.h" #include "python/src/utils.h" -#include "mlx/device.h" -#include "mlx/ops.h" -#include "mlx/transforms.h" -#include "mlx/utils.h" +#include "mlx/mlx.h" namespace mx = mlx::core; namespace nb = nanobind; @@ -197,6 +194,13 @@ void init_array(nb::module_& m) { "max", &mx::finfo::max, R"pbdoc(The largest representable number.)pbdoc") + .def_ro( + "eps", + &mx::finfo::eps, + R"pbdoc( + The difference between 1.0 and the next smallest + representable number larger than 1.0. + )pbdoc") .def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") .def("__repr__", [](const mx::finfo& f) { std::ostringstream os; @@ -312,6 +316,18 @@ void init_array(nb::module_& m) { R"pbdoc( The array's :class:`Dtype`. )pbdoc") + .def_prop_ro( + "real", + [](const mx::array& a) { return mx::real(a); }, + R"pbdoc( + The real part of a complex array. + )pbdoc") + .def_prop_ro( + "imag", + [](const mx::array& a) { return mx::imag(a); }, + R"pbdoc( + The imaginary part of a complex array. 
+ )pbdoc") .def( "item", &to_scalar, @@ -442,9 +458,12 @@ void init_array(nb::module_& m) { .def( "__dlpack_device__", [](const mx::array& a) { + // See + // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74 if (mx::metal::is_available()) { - // Metal device is available return nb::make_tuple(8, 0); + } else if (mx::cu::is_available()) { + return nb::make_tuple(13, 0); } else { // CPU device return nb::make_tuple(1, 0); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 00f8395fc..1340b663a 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -205,6 +205,8 @@ nb::object to_scalar(mx::array& a) { return nb::cast(static_cast(a.item())); case mx::complex64: return nb::cast(a.item>()); + case mx::float64: + return nb::cast(a.item()); default: throw nb::type_error("type cannot be converted to Python scalar."); } diff --git a/python/src/device.cpp b/python/src/device.cpp index 85b15dd4d..006a05dc0 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -58,4 +58,9 @@ void init_device(nb::module_& m) { &mx::set_default_device, "device"_a, R"pbdoc(Set the default device.)pbdoc"); + m.def( + "is_available", + &mx::is_available, + "device"_a, + R"pbdoc(Check if a back-end is available for the given device.)pbdoc"); } diff --git a/python/src/fast.cpp b/python/src/fast.cpp index c94f99e1a..8adba2a25 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -175,11 +175,12 @@ void init_fast(nb::module_& parent_module) { * `Grouped Query Attention `_ * `Multi-Query Attention `_ - Note: The softmax operation is performed in ``float32`` regardless of - the input precision. + .. note:: - Note: For Grouped Query Attention and Multi-Query Attention, the ``k`` - and ``v`` inputs should not be pre-tiled to match ``q``. + * The softmax operation is performed in ``float32`` regardless of + the input precision. + * For Grouped Query Attention and Multi-Query Attention, the ``k`` + and ``v`` inputs should not be pre-tiled to match ``q``. In the following the dimensions are given by: @@ -195,13 +196,30 @@ void init_fast(nb::module_& parent_module) { k (array): Keys with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (Union[None, str, array], optional): A causal, boolean or additive - mask to apply to the query-key scores. The mask can have at most 4 - dimensions and must be broadcast-compatible with the shape - ``[B, N, T_q, T_kv]``. If an additive mask is given its type must - promote to the promoted type of ``q``, ``k``, and ``v``. + mask (Union[None, str, array], optional): The mask to apply to the + query-key scores. The mask can be an array or a string indicating + the mask type. The only supported string type is ``"causal"``. If + the mask is an array it can be a boolean or additive mask. The mask + can have at most 4 dimensions and must be broadcast-compatible with + the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its + type must promote to the promoted type of ``q``, ``k``, and ``v``. Returns: array: The output array. + + Example: + + .. 
code-block:: python + + B = 2 + N_q = N_kv = 32 + T_q = T_kv = 1000 + D = 128 + + q = mx.random.normal(shape=(B, N_q, T_q, D)) + k = mx.random.normal(shape=(B, N_kv, T_kv, D)) + v = mx.random.normal(shape=(B, N_kv, T_kv, D)) + scale = D ** -0.5 + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") )pbdoc"); m.def( diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 5ad4702e2..026f8139d 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -459,4 +459,55 @@ void init_fft(nb::module_& parent_module) { Returns: array: The real array containing the inverse of :func:`rfftn`. )pbdoc"); + m.def( + "fftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::fftshift(a, axes.value(), s); + } else { + return mx::fft::fftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + Shift the zero-frequency component to the center of the spectrum. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the shift. + If ``None``, shift all axes. + + Returns: + array: The shifted array with the same shape as the input. + )pbdoc"); + m.def( + "ifftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::ifftshift(a, axes.value(), s); + } else { + return mx::fft::ifftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes, + the behavior differs for odd-length axes. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the inverse shift. + If ``None``, shift all axes. + + Returns: + array: The inverse-shifted array with the same shape as the input. + )pbdoc"); } diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 3bc0e5b1b..634abaef4 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -236,7 +236,7 @@ void init_linalg(nb::module_& parent_module) { Returns: Union[tuple(array, ...), array]: - If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that + If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``. )pbdoc"); m.def( @@ -407,6 +407,76 @@ void init_linalg(nb::module_& parent_module) { Returns: array: The cross product of ``a`` and ``b`` along the specified axis. )pbdoc"); + m.def( + "eigvals", + &mx::linalg::eigvals, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + R"pbdoc( + Compute the eigenvalues of a square matrix. + + This function differs from :func:`numpy.linalg.eigvals` in that the + return type is always complex even if the eigenvalues are all real. + + This function supports arrays with at least 2 dimensions. When the + input has more than two dimensions, the eigenvalues are computed for + each matrix in the last two dimensions. + + Args: + a (array): The input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The eigenvalues (not necessarily in order). 
+ + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu) + >>> eigenvalues + array([3+0j, -1+0j], dtype=complex64) + )pbdoc"); + m.def( + "eig", + [](const mx::array& a, mx::StreamOrDevice s) { + auto result = mx::linalg::eig(a, s); + return nb::make_tuple(result.first, result.second); + }, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + R"pbdoc( + Compute the eigenvalues and eigenvectors of a square matrix. + + This function differs from :func:`numpy.linalg.eig` in that the + return type is always complex even if the eigenvalues are all real. + + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the eigenvalues and eigenvectors are + computed for each matrix in the last two dimensions. + + Args: + a (array): The input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + Tuple[array, array]: + A tuple containing the eigenvalues and the normalized right + eigenvectors. The column ``v[:, i]`` is the eigenvector + corresponding to the i-th eigenvalue. + + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> w, v = mx.linalg.eig(A, stream=mx.cpu) + >>> w + array([3+0j, -1+0j], dtype=complex64) + >>> v + array([[0.707107+0j, 0.707107+0j], + [-0.707107+0j, 0.707107+0j]], dtype=complex64) + )pbdoc"); + m.def( "eigvalsh", &mx::linalg::eigvalsh, @@ -444,7 +514,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "eigh", - [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) { + [](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) { auto result = mx::linalg::eigh(a, UPLO, s); return nb::make_tuple(result.first, result.second); }, diff --git a/python/src/metal.cpp b/python/src/metal.cpp index a13dd2a03..3b2f4a53a 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -14,7 +14,7 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -bool DEPRECATE(const std::string& old_fn, const std::string new_fn) { +bool DEPRECATE(const char* old_fn, const char* new_fn) { std::cerr << old_fn << " is deprecated and will be removed in a future " << "version. Use " << new_fn << " instead." 
<< std::endl; return true; @@ -49,21 +49,21 @@ void init_metal(nb::module_& m) { metal.def( "set_memory_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit"); + DEPRECATE("mx.metal.set_memory_limit", "mx.set_memory_limit"); return mx::set_memory_limit(limit); }, "limit"_a); metal.def( "set_cache_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit"); + DEPRECATE("mx.metal.set_cache_limit", "mx.set_cache_limit"); return mx::set_cache_limit(limit); }, "limit"_a); metal.def( "set_wired_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit"); + DEPRECATE("mx.metal.set_wired_limit", "mx.set_wired_limit"); return mx::set_wired_limit(limit); }, "limit"_a); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f98aa80aa..9703bbd2d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3076,7 +3076,7 @@ void init_ops(nb::module_& m) { std::tuple, std::pair, std::vector>>& pad_width, - const std::string mode, + const std::string& mode, const ScalarOrArray& constant_value, mx::StreamOrDevice s) { if (auto pv = std::get_if(&pad_width); pv) { @@ -3455,8 +3455,8 @@ void init_ops(nb::module_& m) { 1D convolution over an input with several channels Args: - input (array): Input array of shape ``(N, H, C_in)``. - weight (array): Weight array of shape ``(C_out, H, C_in)``. + input (array): Input array of shape ``(N, L, C_in)``. + weight (array): Weight array of shape ``(C_out, K, C_in)``. stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. dilation (int, optional): Kernel dilation. Default: ``1``. @@ -3514,7 +3514,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3586,7 +3586,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, D, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3609,20 +3609,22 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 1D transposed convolution over an input with several channels Args: - input (array): Input array of shape ``(N, H, C_in)``. - weight (array): Weight array of shape ``(C_out, H, C_in)``. + input (array): Input array of shape ``(N, L, C_in)``. + weight (array): Weight array of shape ``(C_out, K, C_in)``. stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. 
dilation (int, optional): Kernel dilation. Default: ``1``. + output_padding (int, optional): Output padding. Default: ``0``. groups (int, optional): Input feature groups. Default: ``1``. Returns: @@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) { const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, + const std::variant>& output_padding, int groups, mx::StreamOrDevice s) { std::pair stride_pair{1, 1}; std::pair padding_pair{0, 0}; std::pair dilation_pair{1, 1}; + std::pair output_padding_pair{0, 0}; if (auto pv = std::get_if(&stride); pv) { stride_pair = std::pair{*pv, *pv}; @@ -3659,19 +3663,33 @@ void init_ops(nb::module_& m) { dilation_pair = std::get>(dilation); } + if (auto pv = std::get_if(&output_padding); pv) { + output_padding_pair = std::pair{*pv, *pv}; + } else { + output_padding_pair = std::get>(output_padding); + } + return mx::conv_transpose2d( - input, weight, stride_pair, padding_pair, dilation_pair, groups, s); + input, + weight, + stride_pair, + padding_pair, + dilation_pair, + output_padding_pair, + groups, + s); }, nb::arg(), nb::arg(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 2D transposed convolution over an input with several channels @@ -3679,7 +3697,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. 
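+
+        The output spatial size is assumed to follow the same convention as
+        ``torch.conv_transpose2d``, which the tests in this change compare
+        against: for each spatial dimension,
+        ``out = (in - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1``,
+        so ``output_padding`` only ever grows the output.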
Returns: @@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) { const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, + const std::variant>& output_padding, int groups, mx::StreamOrDevice s) { std::tuple stride_tuple{1, 1, 1}; std::tuple padding_tuple{0, 0, 0}; std::tuple dilation_tuple{1, 1, 1}; + std::tuple output_padding_tuple{0, 0, 0}; if (auto pv = std::get_if(&stride); pv) { stride_tuple = std::tuple{*pv, *pv, *pv}; @@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) { dilation_tuple = std::get>(dilation); } + if (auto pv = std::get_if(&output_padding); pv) { + output_padding_tuple = std::tuple{*pv, *pv, *pv}; + } else { + output_padding_tuple = + std::get>(output_padding); + } + return mx::conv_transpose3d( input, weight, stride_tuple, padding_tuple, dilation_tuple, + output_padding_tuple, groups, s); }, @@ -3739,11 +3770,12 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 3D transposed convolution over an input with several channels @@ -3751,7 +3783,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, D, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. Returns: @@ -4286,6 +4321,28 @@ void init_ops(nb::module_& m) { array: The result of the multiplication of ``x`` with ``w`` after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); + m.def( + "segmented_mm", + &mx::segmented_mm, + nb::arg(), + nb::arg(), + "segments"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Perform a matrix multiplication but segment the inner dimension and + save the result for each segment separately. + + Args: + a (array): Input array of shape ``MxK``. + b (array): Input array of shape ``KxN``. + segments (array): The offsets into the inner dimension for each segment. + + Returns: + array: The result per segment of shape ``MxN``. 
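+
+        Example:
+            A small illustrative case (assuming, as in the accompanying
+            tests, that ``segments`` holds ``[start, end]`` offset pairs and
+            that one ``M x N`` result is produced per segment):
+
+            >>> a = mx.ones((2, 6))
+            >>> b = mx.ones((6, 2))
+            >>> segments = mx.array([[0, 3], [3, 6]], dtype=mx.uint32)
+            >>> mx.segmented_mm(a, b, segments).shape
+            (2, 2, 2)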
+ )pbdoc"); m.def( "tensordot", [](const mx::array& a, @@ -5189,4 +5246,46 @@ void init_ops(nb::module_& m) { Returns: array: The row or col contiguous output. )pbdoc"); + m.def( + "broadcast_shapes", + [](const nb::args& shapes) { + if (shapes.size() == 0) + throw std::invalid_argument( + "[broadcast_shapes] Must provide at least one shape."); + + mx::Shape result = nb::cast(shapes[0]); + for (size_t i = 1; i < shapes.size(); ++i) { + if (!nb::isinstance(shapes[i]) && + !nb::isinstance(shapes[i])) + throw std::invalid_argument( + "[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints)."); + result = mx::broadcast_shapes(result, nb::cast(shapes[i])); + } + + return nb::tuple(nb::cast(result)); + }, + nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int]"), + R"pbdoc( + Broadcast shapes. + + Returns the shape that results from broadcasting the supplied array shapes + against each other. + + Args: + *shapes (Sequence[int]): The shapes to broadcast. + + Returns: + tuple: The broadcasted shape. + + Raises: + ValueError: If the shapes cannot be broadcast. + + Example: + >>> mx.broadcast_shapes((1,), (3, 1)) + (3, 1) + >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,)) + (5, 6, 7) + >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1)) + (5, 3, 4) + )pbdoc"); } diff --git a/python/src/random.cpp b/python/src/random.cpp index e9c0a87fc..837f91616 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -152,31 +152,42 @@ void init_random(nb::module_& parent_module) { "normal", [](const mx::Shape& shape, std::optional type, - float loc, - float scale, + const std::optional& loc_, + const std::optional& scale_, const std::optional& key_, mx::StreamOrDevice s) { + auto dtype = type.value_or(mx::float32); auto key = key_ ? key_.value() : default_key().next(); - return mx::random::normal( - shape, type.value_or(mx::float32), loc, scale, key, s); + auto loc = + loc_ ? std::make_optional(to_array(*loc_, dtype)) : std::nullopt; + auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype)) + : std::nullopt; + return mx::random::normal(shape, dtype, loc, scale, key, s); }, "shape"_a = mx::Shape{}, "dtype"_a.none() = mx::float32, - "loc"_a = 0.0, - "scale"_a = 1.0, + "loc"_a = nb::none(), + "scale"_a = nb::none(), "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Generate normally distributed random numbers. + If ``loc`` and ``scale`` are not provided the "standard" normal + distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for + real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0, + \frac{1}{2})$ for complex numbers. + Args: - shape (list(int), optional): Shape of the output. Default is ``()``. - dtype (Dtype, optional): Type of the output. Default is ``float32``. - loc (float, optional): Mean of the distribution. Default is ``0.0``. - scale (float, optional): Standard deviation of the distribution. Default is ``1.0``. - key (array, optional): A PRNG key. Default: None. + shape (list(int), optional): Shape of the output. Default: ``()``. + dtype (Dtype, optional): Type of the output. Default: ``float32``. 
+ loc (scalar or array, optional): Mean of the distribution. + Default: ``None``. + scale (scalar or array, optional): Standard deviation of the + distribution. Default: ``None``. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The output array of random values. @@ -422,7 +433,7 @@ void init_random(nb::module_& parent_module) { axis (int, optional): The axis which specifies the distribution. Default: ``-1``. shape (list(int), optional): The shape of the output. This must - be broadcast compatable with ``logits.shape`` with the ``axis`` + be broadcast compatible with ``logits.shape`` with the ``axis`` dimension removed. Default: ``None`` num_samples (int, optional): The number of samples to draw from each of the categorical distributions in ``logits``. The output will have diff --git a/python/tests/__main__.py b/python/tests/__main__.py new file mode 100644 index 000000000..5230bd535 --- /dev/null +++ b/python/tests/__main__.py @@ -0,0 +1,5 @@ +from . import mlx_tests + +__unittest = True + +mlx_tests.MLXTestRunner(module=None) diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py new file mode 100644 index 000000000..50cb8dcbe --- /dev/null +++ b/python/tests/cuda_skip.py @@ -0,0 +1,88 @@ +cuda_skip = { + "TestLoad.test_load_f8_e4m3", + "TestLayers.test_quantized_embedding", + "TestOps.test_dynamic_slicing", + # Block masked matmul NYI + "TestBlas.test_block_masked_matmul", + # Gather matmul NYI + "TestBlas.test_gather_matmul", + "TestBlas.test_gather_matmul_grad", + "TestBlas.test_gather_mm_sorted", + # Segmented matmul NYI + "TestBlas.test_segmented_mm", + # Hadamard NYI + "TestOps.test_hadamard", + "TestOps.test_hadamard_grad_vmap", + # Convolutions NYI + "TestConv.test_1d_conv_with_2d", + "TestConv.test_asymmetric_padding", + "TestConv.test_basic_grad_shapes", + "TestConv.test_conv2d_unaligned_channels", + "TestConv.test_conv_1d_groups_flipped", + "TestConv.test_conv_general_flip_grad", + "TestConv.test_conv_groups_grad", + "TestConv.test_numpy_conv", + "TestConv.test_repeated_conv", + "TestConv.test_torch_conv_1D", + "TestConv.test_torch_conv_1D_grad", + "TestConv.test_torch_conv_2D", + "TestConv.test_torch_conv_2D_grad", + "TestConv.test_torch_conv_3D", + "TestConv.test_torch_conv_3D_grad", + "TestConv.test_torch_conv_depthwise", + "TestConv.test_torch_conv_general", + "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding", + "TestConvTranspose.test_torch_conv_transpose_1D", + "TestConvTranspose.test_torch_conv_transpose_1D_grad", + "TestConvTranspose.test_torch_conv_transpose_2D", + "TestConvTranspose.test_torch_conv_transpose_2D_grad", + "TestConvTranspose.test_torch_conv_transpose_2d_output_padding", + "TestConvTranspose.test_torch_conv_transpose_3D", + "TestConvTranspose.test_torch_conv_transpose_3D_grad", + "TestConvTranspose.test_torch_conv_transpose_3d_output_padding", + "TestExportImport.test_export_conv", + "TestLayers.test_conv1d", + "TestLayers.test_conv2d", + "TestVmap.test_vmap_conv", + # FFTs NYI + "TestFFT.test_fft", + "TestFFT.test_fft_big_powers_of_two", + "TestFFT.test_fft_contiguity", + "TestFFT.test_fft_exhaustive", + "TestFFT.test_fft_grads", + "TestFFT.test_fft_into_ifft", + "TestFFT.test_fft_large_numbers", + "TestFFT.test_fft_shared_mem", + "TestFFT.test_fftn", + # Lapack ops NYI + "TestLinalg.test_cholesky", + "TestLinalg.test_cholesky_inv", + "TestLinalg.test_eig", + "TestLinalg.test_eigh", + "TestLinalg.test_inverse", + "TestVmap.test_vmap_inverse", + "TestLinalg.test_lu", + "TestLinalg.test_lu_factor", + 
"TestLinalg.test_pseudo_inverse", + "TestLinalg.test_qr_factorization", + "TestInit.test_orthogonal", + "TestLinalg.test_svd_decomposition", + "TestVmap.test_vmap_svd", + "TestLinalg.test_tri_inverse", + # Quantization NYI + "TestQuantized.test_gather_matmul_grad", + "TestQuantized.test_gather_qmm", + "TestQuantized.test_gather_qmm_sorted", + "TestQuantized.test_gather_qmm_grad", + "TestQuantized.test_non_multiples", + "TestQuantized.test_qmm", + "TestQuantized.test_qmm_jvp", + "TestQuantized.test_qmm_shapes", + "TestQuantized.test_qmm_vjp", + "TestQuantized.test_qmv", + "TestQuantized.test_qvm", + "TestQuantized.test_qvm_splitk", + "TestQuantized.test_small_matrix", + "TestQuantized.test_throw", + "TestQuantized.test_vjp_scales_biases", +} diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index f446b5e67..bc197b673 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -1,6 +1,10 @@ # Copyright © 2023 Apple Inc. import os + +# Use regular fp32 precision for tests +os.environ["MLX_ENABLE_TF32"] = "0" + import platform import unittest from typing import Any, Callable, List, Tuple, Union @@ -9,6 +13,42 @@ import mlx.core as mx import numpy as np +class MLXTestRunner(unittest.TestProgram): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def createTests(self, *args, **kwargs): + super().createTests(*args, **kwargs) + + # Asume CUDA backend in this case + device = os.getenv("DEVICE", None) + if device is not None: + device = getattr(mx, device) + else: + device = mx.default_device() + + if not (device == mx.gpu and not mx.metal.is_available()): + return + + from cuda_skip import cuda_skip + + filtered_suite = unittest.TestSuite() + + def filter_and_add(t): + if isinstance(t, unittest.TestSuite): + for sub_t in t: + filter_and_add(sub_t) + else: + t_id = ".".join(t.id().split(".")[-2:]) + if t_id in cuda_skip: + print(f"Skipping {t_id}") + else: + filtered_suite.addTest(t) + + filter_and_add(self.test) + self.test = filtered_suite + + class MLXTestCase(unittest.TestCase): @property def is_apple_silicon(self): diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py index 77d45dbad..213f85274 100644 --- a/python/tests/ring_test_distributed.py +++ b/python/tests/ring_test_distributed.py @@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_array.py b/python/tests/test_array.py index fa5784ea9..3ab41bef7 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -103,10 +103,12 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min) self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max) + self.assertEqual(mx.finfo(mx.float32).eps, np.finfo(np.float32).eps) self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32) self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min) self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) + self.assertEqual(mx.finfo(mx.float16).eps, np.finfo(np.float16).eps) self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) def test_iinfo(self): @@ -196,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase): def test_dlx_device_type(self): a = mx.array([1, 2, 3]) device_type, device_id = a.__dlpack_device__() - self.assertIn(device_type, [1, 8]) + self.assertIn(device_type, [1, 8, 13]) self.assertEqual(device_id, 0) 
if device_type == 8: @@ -1185,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase): check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices( - np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1]) + np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1]) ) # Multiple slices @@ -2020,6 +2022,15 @@ class TestArray(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.add(y, x) + def test_real_imag(self): + x = mx.array([1.0]) + self.assertEqual(x.real.item(), 1.0) + self.assertEqual(x.imag.item(), 0.0) + + x = mx.array([1.0 + 1.0j]) + self.assertEqual(x.imag.item(), 1.0) + self.assertEqual(x.real.item(), 1.0) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index ec9d957ea..5722071f6 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase): x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j]) dfdx = mx.grad(fun)(x) - self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x))) + self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x))) def test_flatten_unflatten_vjps(self): def fun(x): @@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 0b4b49919..2e4e2e0c3 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 6fca4885b..5e096d9c5 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -589,6 +589,10 @@ class TestBlas(mlx_tests.MLXTestCase): alpha = 0.5 beta = 2.0 + # c must broadcast to the output shape + with self.assertRaises(ValueError): + mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2))) + # Regular batched case a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32) @@ -745,6 +749,19 @@ class TestBlas(mlx_tests.MLXTestCase): mx.eval(c) self.assertEqual(c.shape, (0, 0)) + c = mx.array(1.0, dtype=mx.float32) + a = mx.array([], dtype=mx.float32) + b = mx.array([], dtype=mx.float32) + out = mx.addmm(c, a, b) + self.assertEqual(out.item(), 1.0) + self.assertEqual(out.shape, ()) + + a = mx.zeros(shape=(5, 0)) + b = mx.zeros(shape=(0, 5)) + c = mx.random.uniform(shape=(5, 5)) + out = mx.addmm(c, a, b) + self.assertTrue(mx.allclose(out, c)) + def test_block_masked_matmul(self): def ref_block_masked_mm( a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None @@ -1146,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertEqual(r.shape, t.shape) self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + def test_gather_mm_sorted(self): + def gather_mm_ref(a, b, rhs): + b = b[rhs] + return a @ b + + def gather_mm_test(a, b, rhs): + return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True) + + a = mx.random.normal((100, 1, 100)) + b = mx.random.normal((8, 100, 100)) + rhs = mx.sort(mx.random.randint(0, 8, shape=(100,))) + + c1 = gather_mm_ref(a, b, rhs) + c2 = gather_mm_test(a, b, rhs) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + cotan = 
mx.random.normal(c1.shape) + c1, dc1 = mx.vjp( + lambda a, b: gather_mm_ref(a, b, rhs), + [a, b], + [cotan], + ) + c2, dc2 = mx.vjp( + lambda a, b: gather_mm_test(a, b, rhs), + [a, b], + [cotan], + ) + self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4)) + + def test_segmented_mm(self): + def segmented_mm_ref(a, b, s): + s = s.tolist() + c = [] + for s1, s2 in s: + c.append(a[:, s1:s2] @ b[s1:s2, :]) + return mx.stack(c, axis=0) + + shapes = [ + (10, 10, 10), + (10, 10, 1000), + (1000, 1000, 1000), + ] + all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]] + + for M, N, K in shapes: + for s in all_segments: + segments = [] + for i in range(len(s) - 1): + segments.append([s[i], s[i + 1]]) + segments = mx.array(segments) + segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32)) + a = mx.random.normal((M, K)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a, b, segments) + c2 = mx.segmented_mm(a, b, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a.T, b, segments) + c2 = mx.segmented_mm(a.T, b, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((M, K)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a, b.T, segments) + c2 = mx.segmented_mm(a, b.T, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a.T, b.T, segments) + c2 = mx.segmented_mm(a.T, b.T, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + with self.assertRaises(ValueError): + a = mx.ones((2, 10, 10)) + s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32) + mx.segmented_mm(a, a, s) + + a = mx.ones((10, 1000)) + s = mx.random.randint(0, 16, shape=(1000,)) + s = mx.zeros(16, dtype=s.dtype).at[s].add(1) + s = mx.sort(s) + s = mx.cumsum(s) + s = mx.concatenate([mx.array([0]), s]) + s = mx.as_strided(s, (16, 2), (1, 1)) + s = mx.reshape(s, (2, 2, 4, 2)) + c = mx.segmented_mm(a, a.T, s) + self.assertEqual(c.shape, (2, 2, 4, 10, 10)) + def test_gemv_gemm_same_precision(self): mx.random.seed(0) N = 256 @@ -1178,6 +1288,16 @@ class TestBlas(mlx_tests.MLXTestCase): c_np = np.matmul(np.array(a).T, b) self.assertTrue(np.allclose(c, c_np)) + # Check shapes + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + self.assertEqual((a @ b).shape, (2,)) + + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + c = mx.random.normal((2,)) + self.assertEqual(mx.addmm(c, a, b).shape, (2,)) + def test_complex_gemm(self): M = 16 K = 50 @@ -1193,13 +1313,6 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(c, c_np)) # Test addmm - M = 16 - K = 50 - N = 32 - - def rand(shape): - return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) - a = rand((M, K)) b = rand((K, N)) c = rand((M, N)) @@ -1207,6 +1320,13 @@ class TestBlas(mlx_tests.MLXTestCase): out_np = 2.0 * np.matmul(a, b) + 2.0 * c self.assertTrue(np.allclose(out, out_np)) + # complex with real + a = rand((M, K)).real + b = rand((K, N)) + c = mx.matmul(a, b) + c_np = np.matmul(a, b) + self.assertTrue(np.allclose(out, out_np)) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index f5ce496cd..ada2b1484 100644 --- 
a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -2,8 +2,10 @@ import gc import io +import math import unittest from functools import partial +from io import StringIO import mlx.core as mx import mlx_tests @@ -979,6 +981,39 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_double_constant(self): + with mx.stream(mx.cpu): + x = mx.array(1.0, dtype=mx.float64) + + def fun(x): + return (x + math.pi) * 2.0 + + y = fun(x).item() + y_compiled = mx.compile(fun)(x).item() + self.assertEqual(y, y_compiled) + + def test_shared_broadcast(self): + def fun(x, y, z): + yy = mx.broadcast_to(y, z.shape) + return (x + yy * z), yy.sum() + + a = mx.random.normal((10, 10)) + b = mx.array(0.1) + c = mx.random.normal((10, 10)) + mx.eval(a, b, c) + fc = mx.compile(fun) + d = fc(a, b, c) + + s = StringIO() + mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1]) + s.seek(0) + s = s.read() + + self.assertTrue("CompiledBroadcastMultiplyAdd" in s) + d_hat = fun(a, b, c) + self.assertTrue(mx.allclose(d[0], d_hat[0])) + self.assertTrue(mx.allclose(d[1], d_hat[1])) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_constants.py b/python/tests/test_constants.py index 104e7522d..cfd971fbe 100644 --- a/python/tests/test_constants.py +++ b/python/tests/test_constants.py @@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 671c86a32..9be22e01b 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1088,6 +1088,104 @@ class TestConv(mlx_tests.MLXTestCase): atol=2e-5 if dtype == np.float32 else 5e-4, ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_asymmetric_padding(self): + inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32) + strides = (2, 2, 2) + + pt_out = torch.conv3d( + torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)), + torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)), + stride=strides, + padding=2, + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=strides, + padding=([0, 0, 0], [1, 1, 1]), + ) + + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + + inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32) + + pt_out = torch.conv2d( + torch.permute(torch.tensor(inputs), (0, 3, 1, 2)), + torch.permute(torch.tensor(kernel), (0, 3, 1, 2)), + stride=1, + padding=(1, 0), + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=1, + padding=([0, 0], [1, 0]), + ) + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + + def test_basic_grad_shapes(self): + def loss_fn(kernel, inputs, strides, groups): + return mx.sum( + mx.conv_general( + inputs, + kernel, + stride=strides, + groups=groups, + ) + ) + + for in_shape, k_shape, strides, groups in [ + ((3, 5, 4), (6, 2, 2), (2,), 2), + ((3, 5, 4), (24, 2, 1), (2,), 4), + ((3, 5, 5, 4), (6, 2, 2, 2), (2, 1), 2), + ((3, 5, 5, 4), (24, 2, 2, 1), (2, 2), 4), + ]: + grads = mx.grad(loss_fn)( + mx.zeros(k_shape), mx.zeros(in_shape), strides, groups + ) + 
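+            # mx.grad differentiates with respect to the first positional
+            # argument (the kernel), so the gradient shape should match it.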
self.assertEqual(grads.shape, k_shape) + + def test_1d_conv_with_2d(self): + x = mx.random.uniform(shape=(2, 10, 16)) + y = mx.random.normal(shape=(16, 3, 16)) + + out = mx.conv1d(x, y, padding=1) + out_2d = mx.conv2d( + mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0) + ) + + self.assertTrue(mx.allclose(out, out_2d.squeeze(2))) + + x = mx.random.uniform(shape=(2, 10, 4)) + y = mx.random.normal(shape=(4, 3, 4)) + + out = mx.conv1d(x, y, padding=1) + out_2d = mx.conv2d( + mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0) + ) + + self.assertTrue(mx.allclose(out, out_2d.squeeze(2))) + + def test_conv2d_unaligned_channels(self): + x = mx.random.uniform(shape=(2, 16, 16, 21)) + w = mx.random.uniform(shape=(32, 3, 3, 21)) + y = mx.conv2d(x, w, stream=mx.cpu) + y_hat = mx.conv2d(x, w) + self.assertTrue(mx.allclose(y, y_hat)) + + x = mx.random.uniform(shape=(2, 16, 16, 21)) + w = mx.random.uniform(shape=(21, 3, 3, 21)) + y = mx.conv2d(x, w, stream=mx.cpu) + y_hat = mx.conv2d(x, w) + self.assertTrue(mx.allclose(y, y_hat)) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py index 1ac20cbb1..7289955ed 100644 --- a/python/tests/test_conv_transpose.py +++ b/python/tests/test_conv_transpose.py @@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase): N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_tranpose_1d_output_padding(self): + def run_conv_transpose_1d_output_padding( + N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5 + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + iH=iH, + kH=kH, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)) + wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)) + + out_mx = mx.conv_transpose1d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + out_pt = torch.conv_transpose1d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.transpose(out_pt, 2, 1) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for iH, kH, stride, padding, output_padding in ( + (3, 2, 2, 0, 1), + (5, 3, 2, 1, 0), + (7, 4, 3, 1, 2), + ): + run_conv_transpose_1d_output_padding( + N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_transpose_2d_output_padding(self): + def run_conv_transpose_2d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iH, iW = idim + kH, kW = kdim + in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, 
C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)) + wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)) + + out_mx = mx.conv_transpose2d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + out_pt = torch.conv_transpose2d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for idim, kdim, stride, padding, output_padding in ( + ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)), + ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)), + ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)), + ): + run_conv_transpose_2d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype=dtype, + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_transpose_3d_output_padding(self): + def run_conv_transpose_3d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iD, iH, iW = idim + kD, kH, kW = kdim + in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype( + np_dtype + ) + wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype( + np_dtype + ) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)) + wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)) + + out_mx = mx.conv_transpose3d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.conv_transpose3d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for idim, kdim, stride, padding, output_padding in ( + ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)), + ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)), + ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)), + ): + run_conv_transpose_3d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype=dtype, + ) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 53826cad7..d51028def 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -10,7 +10,7 @@ import mlx_tests class TestDefaultDevice(unittest.TestCase): def test_mlx_default_device(self): device = mx.default_device() - if mx.metal.is_available(): + if mx.is_available(mx.gpu): self.assertEqual(device, mx.Device(mx.gpu)) self.assertEqual(str(device), "Device(gpu, 0)") self.assertEqual(device, mx.gpu) @@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase): # Restore device mx.set_default_device(device) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not 
available") def test_device_context(self): default = mx.default_device() diff = mx.cpu if default == mx.gpu else mx.gpu @@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase): self.assertEqual(s2.device, mx.default_device()) self.assertNotEqual(s1, s2) - if mx.metal.is_available(): + if mx.is_available(mx.gpu): s_gpu = mx.default_stream(mx.gpu) self.assertEqual(s_gpu.device, mx.gpu) else: @@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase): s_cpu = mx.new_stream(mx.cpu) self.assertEqual(s_cpu.device, mx.cpu) - if mx.metal.is_available(): + if mx.is_available(mx.gpu): s_gpu = mx.new_stream(mx.gpu) self.assertEqual(s_gpu.device, mx.gpu) else: @@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase): a = mx.add(x, y, stream=mx.default_stream(mx.default_device())) - if mx.metal.is_available(): + if mx.is_available(mx.gpu): b = mx.add(x, y, stream=mx.default_stream(mx.gpu)) self.assertEqual(a.item(), b.item()) s_gpu = mx.new_stream(mx.gpu) @@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_double.py b/python/tests/test_double.py index 10fce0db1..fccf3628f 100644 --- a/python/tests/test_double.py +++ b/python/tests/test_double.py @@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_einsum.py b/python/tests/test_einsum.py index 19ea8178e..a73ea3818 100644 --- a/python/tests/test_einsum.py +++ b/python/tests/test_einsum.py @@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index fcd424343..5d6daaec2 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase): post = mx.get_peak_memory() self.assertEqual(pre, post) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available") def test_multistream_deadlock(self): s1 = mx.default_stream(mx.gpu) s2 = mx.new_stream(mx.gpu) @@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 2b4b425ca..099be0cc0 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -6,6 +6,7 @@ import tempfile import unittest import mlx.core as mx +import mlx.nn as nn import mlx_tests @@ -242,6 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase): def test_leaks(self): path = os.path.join(self.test_dir, "fn.mlxfn") + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: @@ -267,6 +269,83 @@ class TestExportImport(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_export_import_shapeless(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(*args): + return sum(args) + + with mx.exporter(path, fun, shapeless=True) as exporter: + exporter(mx.array(1)) + exporter(mx.array(1), mx.array(2)) + exporter(mx.array(1), mx.array(2), mx.array(3)) + + f2 = mx.import_function(path) + self.assertEqual(f2(mx.array(1))[0].item(), 1) + self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2) + self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3) + with self.assertRaises(ValueError): 
+ f2(mx.array(10), mx.array([5, 10, 20])) + + def test_export_scatter_gather(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(a, b): + return mx.take_along_axis(a, b, axis=0) + + x = mx.random.uniform(shape=(4, 4)) + y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) + mx.export_function(path, fun, (x, y)) + imported_fun = mx.import_function(path) + expected = fun(x, y) + out = imported_fun(x, y)[0] + self.assertTrue(mx.array_equal(expected, out)) + + def fun(a, b, c): + return mx.put_along_axis(a, b, c, axis=0) + + x = mx.random.uniform(shape=(4, 4)) + y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) + z = mx.random.uniform(shape=(2, 4)) + mx.export_function(path, fun, (x, y, z)) + imported_fun = mx.import_function(path) + expected = fun(x, y, z) + out = imported_fun(x, y, z)[0] + self.assertTrue(mx.array_equal(expected, out)) + + def test_export_conv(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.c1 = nn.Conv2d( + 3, 16, kernel_size=3, stride=1, padding=1, bias=False + ) + self.c2 = nn.Conv2d( + 16, 16, kernel_size=3, stride=2, padding=1, bias=False + ) + self.c3 = nn.Conv2d( + 16, 16, kernel_size=3, stride=1, padding=2, bias=False + ) + + def __call__(self, x): + return self.c3(self.c2(self.c1(x))) + + model = Model() + mx.eval(model.parameters()) + + def forward(x): + return model(x) + + input_data = mx.random.normal(shape=(4, 32, 32, 3)) + mx.export_function(path, forward, input_data) + + imported_fn = mx.import_function(path) + out = imported_fn(input_data)[0] + expected = forward(input_data) + self.assertTrue(mx.allclose(expected, out)) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 2c90a3755..13c65de99 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase): )[0] self.assertEqual(out.item(), 2) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_caching(self): + def call_kernel(a: mx.array, source): + kernel = mx.fast.metal_kernel( + name="my_kernel", + input_names=["inp"], + output_names=["out"], + source=source, + ) + return kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(a.size, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + + a = mx.random.normal(shape=(32,)) + + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = 0.0; + """ + + out = call_kernel(a, source) + self.assertTrue(mx.array_equal(out, mx.zeros_like(out))) + + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = 1.0; + """ + out = call_kernel(a, source) + self.assertTrue(mx.array_equal(out, mx.ones_like(out))) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index d35a2b1da..a929e91cf 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_sdpa_vector_batched(self): + D = 64 + q = mx.random.normal(shape=(2, 1, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) 
+ self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2) + v = mx.random.normal(shape=(2, 2, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0) + ref = mlx_ref_attn(q, k, v, mask=mask) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + class TestSDPA(mlx_tests.MLXTestCase): @property @@ -567,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) - def test_sdpa_prommote_mask(self): + def test_sdpa_promote_mask(self): mask = mx.array(2.0, mx.bfloat16) D = 64 Nq = 4 @@ -613,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main(failfast=True) + mlx_tests.MLXTestRunner(failfast=True) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index ec9a48f00..07ab62672 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -7,6 +7,13 @@ import mlx.core as mx import mlx_tests import numpy as np +try: + import torch + + has_torch = True +except ImportError as e: + has_torch = False + class TestFFT(mlx_tests.MLXTestCase): def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs): @@ -194,6 +201,123 @@ class TestFFT(mlx_tests.MLXTestCase): r_np = np.fft.ifft(segment, n=n_fft) self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) + def test_fft_throws(self): + x = mx.array(3.0) + with self.assertRaises(ValueError): + mx.fft.irfftn(x) + + def test_fftshift(self): + # Test 1D arrays + r = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.fftshift, 
np.fft.fftshift, c) + + def test_ifftshift(self): + # Test 1D arrays + r = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c) + + def test_fftshift_errors(self): + # Test invalid axes + x = mx.array(np.random.rand(4, 4).astype(np.float32)) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[2]) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[-3]) + + # Test empty array + x = mx.array([]) + self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x)) + + @unittest.skipIf(not has_torch, "requires PyTorch") + def test_fft_grads(self): + real = [True, False] + inverse = [True, False] + axes = [ + (-1,), + (-2, -1), + ] + shapes = [ + (4, 4), + (2, 4), + (2, 7), + (7, 7), + ] + + mxffts = { + (True, True): mx.fft.irfftn, + (True, False): mx.fft.rfftn, + (False, True): mx.fft.ifftn, + (False, False): mx.fft.fftn, + } + tffts = { + (True, True): torch.fft.irfftn, + (True, False): torch.fft.rfftn, + (False, True): torch.fft.ifftn, + (False, False): torch.fft.fftn, + } + + for r, i, ax, sh in itertools.product(real, inverse, axes, shapes): + + def f(x): + y = mxffts[r, i](x) + return (mx.abs(y) ** 2).sum() + + def g(x): + y = tffts[r, i](x) + return (torch.abs(y) ** 2).sum() + + if r and not i: + x = mx.random.normal(sh) + else: + x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze() + fx = f(x) + gx = g(torch.tensor(x)) + self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4) + + dfdx = mx.grad(f)(x) + dgdx = torch.func.grad(g)(torch.tensor(x)) + self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_graph.py b/python/tests/test_graph.py index 4b8f6d86a..7c6a11412 100644 --- a/python/tests/test_graph.py +++ b/python/tests/test_graph.py @@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_init.py b/python/tests/test_init.py index 4b209736f..046a6e836 100644 --- a/python/tests/test_init.py +++ b/python/tests/test_init.py @@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index ffa355c10..81a43ed7f 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -232,6 +232,11 @@ class TestLinalg(mlx_tests.MLXTestCase): for M, M_plus in zip(AB, pinvs): self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3)) + # Test singular matrix + A = mx.array([[4.0, 1.0], [4.0, 1.0]]) + A_plus = mx.linalg.pinv(A, stream=mx.cpu) + self.assertTrue(mx.allclose(A @ A_plus @ 
A, A)) + def test_cholesky_inv(self): mx.random.seed(7) @@ -307,6 +312,53 @@ class TestLinalg(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.linalg.cross(a, b) + def test_eig(self): + tols = {"atol": 1e-5, "rtol": 1e-5} + + def check_eigs_and_vecs(A_np, kwargs={}): + A = mx.array(A_np) + eig_vals, eig_vecs = mx.linalg.eig(A, stream=mx.cpu, **kwargs) + self.assertTrue( + mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols) + ) + eig_vals_only = mx.linalg.eigvals(A, stream=mx.cpu, **kwargs) + self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols)) + + # Test a simple 2x2 matrix + A_np = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float32) + check_eigs_and_vecs(A_np) + + # Test complex eigenvalues + A_np = np.array([[1.0, -1.0], [1.0, 1.0]], dtype=np.float32) + check_eigs_and_vecs(A_np) + + # Test a larger random symmetric matrix + n = 5 + np.random.seed(1) + A_np = np.random.randn(n, n).astype(np.float32) + check_eigs_and_vecs(A_np) + + # Test with batched input + A_np = np.random.randn(3, n, n).astype(np.float32) + check_eigs_and_vecs(A_np) + + # Test error cases + with self.assertRaises(ValueError): + mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eig( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix + + with self.assertRaises(ValueError): + mx.linalg.eigvals(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eigvals( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix + def test_eigh(self): tols = {"atol": 1e-5, "rtol": 1e-5} @@ -341,6 +393,13 @@ class TestLinalg(mlx_tests.MLXTestCase): A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2 check_eigs_and_vecs(A_np) + # Test with complex inputs + A_np = ( + np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1) + ) + A_np = A_np + A_np.T.conj() + check_eigs_and_vecs(A_np) + # Test error cases with self.assertRaises(ValueError): mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array @@ -486,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 341564dae..840d3b471 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -391,13 +391,15 @@ class TestLoad(mlx_tests.MLXTestCase): scale = mx.array(2.0) y = mx.load(save_file) mx.eval(y) + mx.synchronize() load_only = mx.get_peak_memory() y = mx.load(save_file) * scale mx.eval(y) + mx.synchronize() load_with_binary = mx.get_peak_memory() self.assertEqual(load_only, load_with_binary) if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index 102ec857d..2ef1fa36c 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase): logits, targets, reduction="mean" ) expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' losses_sum = nn.losses.binary_cross_entropy( logits, targets, reduction="sum" ) expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) # With weights, no label smoothing weights = mx.array([1.0, 2.0, 1.0, 2.0]) @@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase): if __name__ == 
"__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_memory.py b/python/tests/test_memory.py index 7343bdc91..08da7ccc6 100644 --- a/python/tests/test_memory.py +++ b/python/tests/test_memory.py @@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9cfa25dae..ae3fae4da 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -8,7 +8,7 @@ import mlx.core as mx import mlx.nn as nn import mlx_tests import numpy as np -from mlx.utils import tree_flatten, tree_map +from mlx.utils import tree_flatten, tree_map, tree_reduce class TestBase(mlx_tests.MLXTestCase): @@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase): self.assertTrue(isinstance(m.layers[1], nn.ReLU)) self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + def test_quantize_freeze(self): + lin = nn.Linear(512, 512) + qlin = lin.to_quantized() + qlin.unfreeze(keys=["scales"]) + size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0) + self.assertTrue(size > 0) + def test_grad_of_module(self): class Model(nn.Module): def __init__(self): @@ -212,6 +219,66 @@ class TestBase(mlx_tests.MLXTestCase): x = mx.zeros((3,)) mx.grad(loss_fn)(model) + def test_update(self): + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + + # Updating non-existent parameters + with self.assertRaises(ValueError): + updates = {"layers": [{"value": 0}]} + m.update(updates) + + with self.assertRaises(ValueError): + updates = {"layers": ["hello"]} + m.update(updates) + + # Wronge type + with self.assertRaises(ValueError): + updates = {"layers": [{"weight": "hi"}]} + m.update(updates) + + def test_update_modules(self): + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + + # Updating non-existent modules should not be allowed by default + with self.assertRaises(ValueError): + m = m.update_modules({"values": [0, 1]}) + + # Update wrong types + with self.assertRaises(ValueError): + m = m.update_modules({"layers": [0, 1]}) + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.test = mx.array(1.0) + self.list = [mx.array(1.0), mx.array(2.0)] + + m = MyModule() + with self.assertRaises(ValueError): + m = m.update_modules({"test": "hi"}) + with self.assertRaises(ValueError): + m = m.update_modules({"list": ["hi"]}) + + # Allow updating a strict subset + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + m.update_modules({"layers": [{}, nn.Linear(3, 4)]}) + self.assertEqual(m.layers[1].weight.shape, (4, 3)) + + # Using leaf_modules in the update should always work + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)] + self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0} + + m = MyModel() + m.update_modules(m.leaf_modules()) + + def test_parameter_deletion(self): + m = nn.Linear(32, 32) + del m.weight + self.assertFalse(hasattr(m, "weight")) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): @@ -1860,4 +1927,4 @@ class TestLayers(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4fcb31f18..bbea9ad8e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -10,6 +10,47 @@ import mlx_tests import numpy as np +def np_wrap_between(x, a): + """Wraps `x` between `[-a, a]`.""" + two_a = 2 * a + zero = 
0 + rem = np.remainder(np.add(x, a), two_a) + if isinstance(rem, np.ndarray): + rem = np.select(rem < zero, np.add(rem, two_a), rem) + else: + rem = np.add(rem, two_a) if rem < zero else rem + return np.subtract(rem, a) + + +def np_logaddexp(x1: np.ndarray, x2: np.ndarray): + amax = np.maximum(x1, x2) + if np.issubdtype(x1.dtype, np.floating): + delta = np.subtract(x1, x2) + if isinstance(delta, np.ndarray): + return np.select( + np.isnan(delta), + np.add(x1, x2), + np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))), + ) + else: + return ( + np.add(x1, x2) + if np.isnan(delta) + else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))) + ) + else: + delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2)) + out = np.add(amax, np.log1p(np.exp(delta))) + return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi) + + +def np_cumlogaddexp(x1: np.ndarray, axis: int = -1): + out = x1.copy() + for i in range(1, out.shape[axis]): + out[i] = np_logaddexp(out[i], out[i - 1]) + return out + + class TestOps(mlx_tests.MLXTestCase): def test_full_ones_zeros(self): x = mx.full(2, 3.0) @@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + + a = mx.array([0, 1, 2, 9.0]) + 1j + b = mx.array([1, 0, 4, 2.5]) + 1j + + result = mx.logaddexp(a, b) + expected = np_logaddexp(np.array(a), np.array(b)) + + self.assertTrue(np.allclose(result, expected)) + a = mx.array([float("nan")]) b = mx.array([0.0]) self.assertTrue(math.isnan(mx.logaddexp(a, b).item())) @@ -977,6 +1028,13 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + a = mx.array([1, 0.5, 10, 100]) + 1j + result = mx.log1p(a) + expected = np.log1p(a, dtype=np.complex64) + + self.assertTrue(np.allclose(result, expected)) + def test_sigmoid(self): a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0]) result = mx.sigmoid(a) @@ -1197,6 +1255,12 @@ class TestOps(mlx_tests.MLXTestCase): np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2) self.assertTrue(np.array_equal(out_np, np.array(out_mlx))) + a = mx.array([], mx.float32) + b = mx.put_along_axis(a, a, a, axis=None) + mx.eval(b) + self.assertEqual(b.size, 0) + self.assertEqual(b.shape, a.shape) + def test_split(self): a = mx.array([1, 2, 3]) splits = mx.split(a, 3) @@ -1414,7 +1478,7 @@ class TestOps(mlx_tests.MLXTestCase): r_mlx = mlxop(y) mx.eval(r_mlx) - self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True)) x = np.random.rand(9, 12, 18) xi = np.random.rand(9, 12, 18) @@ -1881,10 +1945,31 @@ class TestOps(mlx_tests.MLXTestCase): c_mlx = mxop(a_mlx, axis=0) self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + # Complex tests + + a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j + a_mlx = mx.array(a_npy) + c_npy = np_cumlogaddexp(a_npy, axis=-1) + c_mlx = mxop(a_mlx, axis=-1) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: + npop = getattr(np, op) + mxop = getattr(mx, op) + for axis in (None, 0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + + # Complex test + + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j + a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: npop = getattr(np, op) mxop = getattr(mx, op) 
@@ -2501,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqualArray(result, mx.array(expected)) def test_atleast_1d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2529,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_1d(mx.array(array)) np_res = np.atleast_1d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_atleast_2d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2563,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_2d(mx.array(array)) np_res = np.atleast_2d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_atleast_3d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2597,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_3d(mx.array(array)) np_res = np.atleast_3d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_issubdtype(self): self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact)) @@ -2751,6 +2800,9 @@ class TestOps(mlx_tests.MLXTestCase): return H def test_hadamard(self): + with self.assertRaises(ValueError): + mx.hadamard_transform(mx.array([])) + h28_str = """ +------++----++-+--+-+--++-- -+-----+++-----+-+--+-+--++- @@ -2789,11 +2841,33 @@ class TestOps(mlx_tests.MLXTestCase): h28 = parse_h_string(h28_str) + x = mx.array(5) + y = mx.hadamard_transform(x) + self.assertEqual(y.item(), 5) + + x = mx.array(5) + y = mx.hadamard_transform(x, scale=0.2) + self.assertEqual(y.item(), 1) + + x = mx.random.normal((8, 8, 1)) + y = mx.hadamard_transform(x) + self.assertTrue(mx.all(y == x).item()) + + # Too slow to compare to numpy so let's compare CPU to GPU + if mx.default_device() == mx.gpu: + rk = mx.random.key(42) + for k in range(14, 17): + for m in [1, 3, 5, 7]: + x = mx.random.normal((4, m * 2**k), key=rk) + y1 = mx.hadamard_transform(x, stream=mx.cpu) + y2 = mx.hadamard_transform(x, stream=mx.gpu) + self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6) + np.random.seed(7) - tests = product([np.float32, np.float16, 
np.int32], [1, 28], range(1, 15)) + tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14)) for dtype, m, k in tests: # skip large m=28 cases because they're very slow in NumPy - if (m > 1 and k > 8) or (dtype != np.float16 and k == 14): + if m > 1 and k > 8: continue with self.subTest(dtype=dtype, m=m, k=k): n = m * 2**k @@ -2882,6 +2956,11 @@ class TestOps(mlx_tests.MLXTestCase): y2 = mx.roll(x, s, a) self.assertTrue(mx.array_equal(y1, y2).item()) + def test_roll_errors(self): + x = mx.array([]) + result = mx.roll(x, [0], [0]) + self.assertTrue(mx.array_equal(result, x)) + def test_real_imag(self): x = mx.random.uniform(shape=(4, 4)) out = mx.real(x) @@ -2934,6 +3013,82 @@ class TestOps(mlx_tests.MLXTestCase): out = a[::-1] self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) + def test_complex_ops(self): + x = mx.array( + [ + 3.0 + 4.0j, + -5.0 + 12.0j, + -8.0 + 0.0j, + 0.0 + 9.0j, + 0.0 + 0.0j, + ] + ) + + ops = ["arccos", "arcsin", "arctan", "square", "sqrt"] + for op in ops: + with self.subTest(op=op): + np_op = getattr(np, op) + mx_op = getattr(mx, op) + self.assertTrue(np.allclose(mx_op(x), np_op(x))) + + x = mx.array( + [ + 3.0 + 4.0j, + -5.0 + 12.0j, + -8.0 + 0.0j, + 0.0 + 9.0j, + 9.0 + 1.0j, + ] + ) + self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) + + def test_complex_power(self): + out = mx.power(mx.array(0j), 2) + self.assertEqual(out.item(), 0j) + + out = mx.power(mx.array(0j), float("nan")) + self.assertTrue(mx.isnan(out)) + + +class TestBroadcast(mlx_tests.MLXTestCase): + def test_broadcast_shapes(self): + # Basic broadcasting + self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) + self.assertEqual(mx.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) + self.assertEqual(mx.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) + + # Multiple arguments + self.assertEqual(mx.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) + self.assertEqual( + mx.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) + ) + + # Same shapes + self.assertEqual(mx.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) + + # Single argument + self.assertEqual(mx.broadcast_shapes((2, 3)), (2, 3)) + + # Empty shapes + self.assertEqual(mx.broadcast_shapes((), ()), ()) + self.assertEqual(mx.broadcast_shapes((), (1,)), (1,)) + self.assertEqual(mx.broadcast_shapes((1,), ()), (1,)) + + # Broadcasting with zeroes + self.assertEqual(mx.broadcast_shapes((0,), (0,)), (0,)) + self.assertEqual(mx.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) + self.assertEqual(mx.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) + + # Error cases + with self.assertRaises(ValueError): + mx.broadcast_shapes((3, 4), (4, 3)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes((2, 3, 4), (2, 5, 4)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes() + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index ebfe97d80..8f9e33679 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase): ) ) + # Test for correct gradient type propagation + params = tree_map(lambda x: x.astype(mx.float16), params) + grads = tree_map(lambda x: x.astype(mx.float16), grads) + optim = opt.Adam(1e-2, bias_correction=True) + new_params = optim.apply_gradients(grads, params) + self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params)) + @unittest.skipIf(not has_torch, "requires Torch") def 
test_adamw_matches_pytorch(self): mx.random.seed(0) @@ -353,7 +360,7 @@ class TestOptimizers(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) -class TestSchedulers(unittest.TestCase): +class TestSchedulers(mlx_tests.MLXTestCase): def test_decay_lr(self): for optim_class in optimizers_dict.values(): lr_schedule = opt.step_decay(1e-1, 0.9, 1) @@ -527,4 +534,4 @@ class TestSchedulers(unittest.TestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index eeefcd94f..2c62c6307 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 6, 4, 8]: + for b in [2, 3, 5, 6, 4, 8]: with self.subTest(gs=gs, b=b): w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b) w_hat = mx.dequantize(w_q, scales, biases, gs, b) @@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase): # test quantize/dequantize 0s a = mx.zeros((256, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 4, 6, 8]: + for b in [2, 3, 4, 5, 6, 8]: w_q, scales, biases = mx.quantize(a, gs, b) a_hat = mx.dequantize(w_q, scales, biases, gs, b) self.assertTrue(mx.all(a_hat == 0)) @@ -146,7 +146,7 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 3, 4, 6, 8], # bits + [2, 3, 4, 5, 6, 8], # bits [256, 512, 67], # M [64, 128], # N [0, 1, 3, 8], # B @@ -173,7 +173,7 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 3, 4, 6, 8], # bits + [2, 3, 4, 5, 6, 8], # bits [32, 128, 256], # M [128, 256, 67], # N [0, 1, 3, 8], # B @@ -549,6 +549,74 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + def test_gather_qmm_grad(self): + def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort): + if lhs is not None: + x = x[lhs] + if rhs is not None: + w = w[rhs] + s = s[rhs] + b = b[rhs] + return mx.quantized_matmul(x, w, s, b, transpose=trans) + + def gather_qmm(x, w, s, b, lhs, rhs, trans, sort): + return mx.gather_qmm( + x, + w, + s, + b, + transpose=trans, + lhs_indices=lhs, + rhs_indices=rhs, + sorted_indices=sort, + ) + + x = mx.random.normal((16, 1, 256)) + w, s, b = mx.quantize(mx.random.normal((4, 256, 256))) + indices = mx.sort(mx.random.randint(0, 4, shape=(16,))) + cotan = mx.random.normal((16, 1, 256)) + + (o1,), (dx1, ds1, db1) = mx.vjp( + lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + (o2,), (dx2, ds2, db2) = mx.vjp( + lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + + self.assertTrue(mx.allclose(o1, o2, atol=1e-4)) + self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4)) + self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3)) + self.assertTrue(mx.allclose(db1, db2, atol=1e-3)) + + def test_vjp_scales_biases(self): + mx.random.seed(0) + x = mx.random.normal(shape=(2, 2, 512)) + w = mx.random.normal(shape=(512, 512)) + wq, s, b = mx.quantize(w, bits=4, group_size=64) + + def mm(sb, x, wq): + return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum() + + params = (s, b) + dparams = mx.grad(mm)((s, b), x, wq) + + eps = 8e-3 + # numerical grad check with a few 
indices + indices = [(0, 0), (11, 4), (22, 7)] + for idx in indices: + for p in [0, 1]: + params[p][idx] += eps + out_up = mm(params, x, wq) + params[p][idx] -= 2 * eps + out_down = mm(params, x, wq) + params[p][idx] += eps + num_ds = (out_up - out_down) / (2 * eps) + self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 9efbfb5f6..551c32993 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase): x = mx.random.permutation(mx.array([[1]])) self.assertEqual(x.shape, (1, 1)) + def test_complex_normal(self): + sample = mx.random.normal(tuple(), dtype=mx.complex64) + self.assertEqual(sample.shape, tuple()) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal( + (1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j + ) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + def test_broadcastable_scale_loc(self): + b = mx.random.normal((10, 2)) + sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (2, 10, 2)) + + with self.assertRaises(ValueError): + b = mx.random.normal((10,)) + sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + + b = mx.random.normal((3, 1, 2)) + sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (3, 4, 2)) + self.assertEqual(sample.dtype, mx.float16) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 9012216ba..d6ddf353b 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -153,6 +153,63 @@ class TestReduce(mlx_tests.MLXTestCase): x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9) check(x, (1, 3, 5, 7, 9)) + def test_nan_propagation(self): + dtypes = [ + "uint8", + "uint16", + "uint32", + "int8", + "int16", + "int32", + "float16", + "float32", + ] + + for dtype in dtypes: + with self.subTest(dtype=dtype): + x = (mx.random.normal((4, 4)) * 10).astype(getattr(mx, dtype)) + indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2) + for idx in indices: + x[idx[0], idx[1]] = mx.nan + x_np = np.array(x) + + for op in ["max", "min"]: + for axis in [0, 1]: + out = getattr(mx, op)(x, axis=axis) + ref = getattr(np, op)(x_np, axis=axis) + self.assertTrue(np.array_equal(out, ref, equal_nan=True)) + + def test_nan_propagation_complex64(self): + complex_array_1 = mx.array( + [1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_2 = mx.array( + [1 + 1j, 2 + 2j, 3 + mx.nan * 1j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_3 = mx.array( + [1 + 1j, 2 + mx.nan * 1j, 3 + 3j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_4 = mx.array( + [mx.nan + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + + np_arrays = [ + np.array(complex_array_1), + np.array(complex_array_2), + np.array(complex_array_3), + np.array(complex_array_4), + 
] + + for mx_arr, np_arr in zip( + [complex_array_1, complex_array_2, complex_array_3, complex_array_4], + np_arrays, + ): + for axis in [0, 1]: + for op in ["max", "min"]: + out = getattr(mx, op)(mx_arr, axis=axis) + ref = getattr(np, op)(np_arr, axis=axis) + self.assertTrue(np.array_equal(out, ref, equal_nan=True)) + if __name__ == "__main__": - unittest.main(failfast=True) + mlx_tests.MLXTestRunner(failfast=True) diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py index 63018fdae..bacf6e71d 100644 --- a/python/tests/test_tree.py +++ b/python/tests/test_tree.py @@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py index 402c7b0ca..631853cce 100644 --- a/python/tests/test_upsample.py +++ b/python/tests/test_upsample.py @@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase): align_corners=align_corner, )(in_mx) mode_pt = { + "nearest": "nearest", "linear": "bilinear", "cubic": "bicubic", }[mode] @@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase): in_pt, scale_factor=scale_factor, mode=mode_pt, - align_corners=align_corner, + align_corners=align_corner if mode != "nearest" else None, ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) self.assertEqual(out_pt.shape, out_mx.shape) @@ -76,14 +77,14 @@ class TestUpsample(mlx_tests.MLXTestCase): ((4, 4), (0.5, 0.5)), ((7, 7), (2.0, 2.0)), ((10, 10), (0.2, 0.2)), + ((10, 10), (0.3, 0.3)), ((11, 21), (3.0, 3.0)), ((11, 21), (3.0, 2.0)), ): - # only test linear and cubic interpolation - # there will be numerical difference in nearest - # due to different indices selection. - for mode in ("cubic", "linear"): + for mode in ("cubic", "linear", "nearest"): for align_corner in (False, True): + if mode == "nearest" and align_corner: + continue run_upsample( N, C, @@ -96,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 1a1ba23b3..a88e59585 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -634,6 +634,8 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(fy.shape, (4, 5, 6, 7)) def test_leaks(self): + gc.collect() + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: @@ -652,6 +654,7 @@ class TestVmap(mlx_tests.MLXTestCase): outer() gc.collect() + mx.synchronize() if mx.metal.is_available(): mem_post = mx.get_active_memory() else: @@ -669,6 +672,57 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8)) self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6)) + def test_vmap_conv(self): + # vmap input only + x = mx.random.uniform(shape=(2, 2, 5, 4)) + w = mx.random.uniform(shape=(8, 3, 4)) + + expected = mx.stack([mx.conv1d(xi, w) for xi in x]) + out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + x = mx.moveaxis(x, 0, 2) + out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # vmap weights only + x = mx.random.uniform(shape=(2, 5, 4)) + w = mx.random.uniform(shape=(3, 8, 3, 4)) + + expected = mx.stack([mx.conv1d(x, wi) for wi in w]) + out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + w = mx.moveaxis(w, 0, 1) + out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, 
w) + self.assertTrue(mx.allclose(expected, out)) + + # vmap weights and input + x = mx.random.uniform(shape=(3, 2, 5, 4)) + w = mx.random.uniform(shape=(3, 8, 3, 4)) + + expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)]) + out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + x = mx.random.uniform(shape=(2, 3, 5, 4)) + w = mx.random.uniform(shape=(8, 3, 4, 3)) + + expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)]) + out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # Test with groups + x = mx.random.uniform(shape=(3, 2, 5, 8)) + w = mx.random.uniform(shape=(3, 2, 3, 4)) + + def gconv(x, w): + return mx.conv1d(x, w, groups=2) + + expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)]) + out = mx.vmap(gconv, in_axes=(0, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/setup.py b/setup.py index d742e6595..6cc4015c3 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,12 @@ import os import platform import re import subprocess +from functools import partial from pathlib import Path from subprocess import run -from setuptools import Command, Extension, find_namespace_packages, setup +from setuptools import Command, Extension, setup +from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.build_ext import build_ext @@ -41,6 +43,9 @@ def get_version(): return version +build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0)) + + # A CMakeExtension needs a sourcedir instead of a file list. # The name must be the _single_ output extension from the CMake build. # If you need multiple extensions, see scikit-build. @@ -59,13 +64,22 @@ class CMakeBuild(build_ext): debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug cfg = "Debug" if debug else "Release" - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + build_python = "ON" + install_prefix = f"{extdir}{os.sep}" + if build_stage == 1: + # Don't include MLX libraries in the wheel + install_prefix = f"{build_temp}" + elif build_stage == 2: + # Don't include Python bindings in the wheel + build_python = "OFF" cmake_args = [ - f"-DCMAKE_INSTALL_PREFIX={extdir}{os.sep}", + f"-DCMAKE_INSTALL_PREFIX={install_prefix}", f"-DCMAKE_BUILD_TYPE={cfg}", - "-DMLX_BUILD_PYTHON_BINDINGS=ON", + f"-DMLX_BUILD_PYTHON_BINDINGS={build_python}", "-DMLX_BUILD_TESTS=OFF", "-DMLX_BUILD_BENCHMARKS=OFF", "-DMLX_BUILD_EXAMPLES=OFF", @@ -97,15 +111,7 @@ class CMakeBuild(build_ext): # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level # across all generators. if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. 
- build_args += [f"-j{self.parallel}"] - - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) + build_args += [f"-j{os.cpu_count()}"] subprocess.run( ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True @@ -162,50 +168,118 @@ class GenerateStubs(Command): subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) +class MLXBdistWheel(bdist_wheel): + def get_tag(self) -> tuple[str, str, str]: + impl, abi, plat_name = super().get_tag() + if build_stage == 2: + impl = self.python_tag + abi = "none" + return (impl, abi, plat_name) + + # Read the content of README.md with open(Path(__file__).parent / "README.md", encoding="utf-8") as f: long_description = f.read() -# The information here can also be placed in setup.cfg - better separation of -# logic and declaration, and simpler if you include description/version in a file. -if __name__ == "__main__": - packages = find_namespace_packages( - where="python", exclude=["src", "tests", "tests.*"] - ) - package_dir = {"": "python"} - package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} - setup( - name="mlx", - version=get_version(), +if __name__ == "__main__": + package_dir = {"": "python"} + packages = [ + "mlx", + "mlx.nn", + "mlx.nn.layers", + "mlx.optimizers", + ] + + build_macos = platform.system() == "Darwin" + build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") + + install_requires = [] + if build_cuda: + install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"] + version = get_version() + + _setup = partial( + setup, + version=version, author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", long_description=long_description, long_description_content_type="text/markdown", + license="MIT", url="https://github.com/ml-explore/mlx", - packages=packages, - package_dir=package_dir, - package_data=package_data, include_package_data=True, - extras_require={ - "dev": [ - "nanobind==2.4.0", - "numpy", - "pre-commit", - "setuptools>=42", - "torch", - "typing_extensions", - ], - }, - entry_points={ - "console_scripts": [ - "mlx.launch = mlx.distributed_run:main", - "mlx.distributed_config = mlx.distributed_run:distributed_config", - ] - }, - ext_modules=[CMakeExtension("mlx.core")], - cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, + package_dir=package_dir, zip_safe=False, python_requires=">=3.9", + ext_modules=[CMakeExtension("mlx.core")], + cmdclass={ + "build_ext": CMakeBuild, + "generate_stubs": GenerateStubs, + "bdist_wheel": MLXBdistWheel, + }, ) + + package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} + + extras = { + "dev": [ + "nanobind==2.4.0", + "numpy", + "pre-commit", + "setuptools>=80", + "torch", + "typing_extensions", + ], + } + entry_points = { + "console_scripts": [ + "mlx.launch = mlx.distributed_run:main", + "mlx.distributed_config = mlx.distributed_run:distributed_config", + ] + } + + # Release builds for PyPi are in two stages. + # Each stage should be run from a clean build: + # python setup.py clean --all + # + # Stage 1: + # - Triggered with `MLX_BUILD_STAGE=1` + # - Include everything except backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc) + # - Wheel has Python ABI and platform tags + # - Wheel should be built for the cross-product of python version and platforms + # - Package name is mlx and it depends on subpackage in stage 2 (e.g. 
mlx-metal) + # Stage 2: + # - Triggered with `MLX_BUILD_STAGE=2` + # - Includes only backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc) + # - Wheel has only platform tags + # - Wheel should be built only for different platforms + # - Package name is back-end specific, e.g mlx-metal + if build_stage != 2: + if build_stage == 1: + if build_macos: + install_requires += [f"mlx-metal=={version}"] + else: + extras["cuda"] = [f"mlx-cuda=={version}"] + extras["cpu"] = [f"mlx-cpu=={version}"] + + _setup( + name="mlx", + packages=packages, + extras_require=extras, + entry_points=entry_points, + install_requires=install_requires, + package_data=package_data, + ) + else: + if build_macos: + name = "mlx-metal" + elif build_cuda: + name = "mlx-cuda" + else: + name = "mlx-cpu" + _setup( + name=name, + packages=["mlx"], + ) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index be4479e70..cb174865d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -9,8 +9,8 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) -if(MLX_BUILD_METAL) - set(METAL_TEST_SOURCES metal_tests.cpp) +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + set(METAL_TEST_SOURCES gpu_tests.cpp) endif() include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index c992c3c6d..5b3454bfc 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") { } { + auto multiply_fn = + [](const std::vector& inputs) -> std::vector { + return {multiply(inputs[0], inputs[1])}; + }; + // Compute jvp auto x = array(complex64_t{2.0, 4.0}); auto y = array(3.0f); - auto x_tan = array(complex64_t{1.0, 2.0}); auto y_tan = array(2.0f); + auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ(jvp_out[0].item(), complex64_t{7.0, 14.0}); - auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second; - CHECK_EQ(out.item(), complex64_t{4.0, 8.0}); - - out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second; - CHECK_EQ(out.item(), complex64_t{3.0, 6.0}); - + // Compute vjp auto cotan = array(complex64_t{2.0, 3.0}); - out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second; - CHECK_EQ(out.dtype(), float32); - CHECK_EQ(out.item(), -8.0); + auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].dtype(), complex64); + CHECK_EQ(vjp_out[0].item(), complex64_t{6.0, 9.0}); + CHECK_EQ(vjp_out[1].dtype(), float32); + CHECK_EQ(vjp_out[1].item(), 16); + } - out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second; - CHECK_EQ(out.item(), complex64_t{6.0, 9.0}); + { + auto divide_fn = + [](const std::vector& inputs) -> std::vector { + return {divide(inputs[0], inputs[1])}; + }; + + // Compute jvp + auto x = array(complex64_t{2.0, 3.0}); + auto y = array(complex64_t{1.0, 2.0}); + auto x_tan = array(complex64_t{3.0, 4.0}); + auto y_tan = array(complex64_t{4.0, -2.0}); + auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ( + jvp_out[0].item(), doctest::Approx(complex64_t{2.6, 2.8})); + + // Compute vjp + auto cotan = array(complex64_t{2.0, -4.0}); + auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].item(), complex64_t{2.0, 0.0}); + CHECK_EQ(vjp_out[1].item(), complex64_t{-3.2, -0.4}); } } diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 66511682d..96552ef9d 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ 
-795,3 +795,12 @@ TEST_CASE("test compile lambda") { out = cfun2({array(0)}); CHECK_EQ(out[0].item(), 3); } + +TEST_CASE("test compile with no-ops") { + auto fun = [](const std::vector& inputs) { + return std::vector{abs(stop_gradient(abs(inputs[0])))}; + }; + auto in = array(1.0); + auto out = compile(fun)({in})[0]; + CHECK_EQ(out.inputs()[0].id(), in.id()); +} diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index c04dda1d5..b9e2d1bcc 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -243,7 +243,7 @@ TEST_CASE("test fft grads") { auto fft_fn = [](array x) { return fft::fft(x); }; auto cotangent = astype(arange(10), complex64); auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::fft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item()); auto tangent = astype(arange(10), complex64); auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second; @@ -252,7 +252,7 @@ TEST_CASE("test fft grads") { // Inverse auto ifft_fn = [](array x) { return fft::ifft(x); }; vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::ifft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item()); jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second; CHECK(array_equal(fft::ifft(tangent), jvp_out).item()); @@ -261,7 +261,8 @@ TEST_CASE("test fft grads") { auto rfft_fn = [](array x) { return fft::rfft(x); }; cotangent = astype(arange(6), complex64); vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second; - auto expected = astype(fft::fft(cotangent, 10, 0), float32); + array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64); + auto expected = fft::irfft(cotangent * mask, 10, 0) * 10; CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), float32); @@ -272,12 +273,9 @@ TEST_CASE("test fft grads") { auto irfft_fn = [](array x) { return fft::irfft(x); }; cotangent = astype(arange(10), float32); vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second; - expected = fft::fft(cotangent, 10, 0); - auto o_splits = split(vjp_out, {1, 5}); - auto e_splits = split(expected, {1, 5, 6}); - CHECK_EQ(e_splits[0].item(), o_splits[0].item()); - CHECK(array_equal(2 * e_splits[1], o_splits[1]).item()); - CHECK_EQ(e_splits[2].item(), o_splits[2].item()); + mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32); + expected = fft::rfft(cotangent) * mask; + CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), complex64); jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second; @@ -308,3 +306,61 @@ TEST_CASE("test fft grads") { .second; CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } + +TEST_CASE("test fftshift and ifftshift") { + // Test 1D array with even length + auto x = arange(8); + auto y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + // print y + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test 1D array with odd length + x = arange(7); + y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item()); + + // Test 2D array + x = reshape(arange(16), {4, 4}); + y = fft::fftshift(x); + auto expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test with specific axes + y = fft::fftshift(x, {0}); + expected = + array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + y = fft::fftshift(x, 
{1}); + expected = + array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test ifftshift (inverse operation) + x = arange(8); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test ifftshift with odd length (different from fftshift) + x = arange(7); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item()); + + // Test 2D ifftshift + x = reshape(arange(16), {4, 4}); + y = fft::ifftshift(x); + expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test error cases + CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument); +} diff --git a/tests/metal_tests.cpp b/tests/gpu_tests.cpp similarity index 95% rename from tests/metal_tests.cpp rename to tests/gpu_tests.cpp index 7aabdf36d..f0ef969cf 100644 --- a/tests/metal_tests.cpp +++ b/tests/gpu_tests.cpp @@ -1,11 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "doctest/doctest.h" -#include "mlx/backend/metal/allocator.h" -#include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal.h" +#include "doctest/doctest.h" #include "mlx/mlx.h" using namespace mlx::core; @@ -13,13 +10,7 @@ using namespace mlx::core; static const std::array types = {bool_, uint32, int32, int64, float32}; -TEST_CASE("test metal device") { - // Make sure the device and library can load - CHECK(metal::is_available()); - auto& device = metal::device(Device::gpu); -} - -TEST_CASE("test metal arange") { +TEST_CASE("test gpu arange") { for (auto t : types) { if (t == bool_) { continue; @@ -34,7 +25,7 @@ TEST_CASE("test metal arange") { } } -TEST_CASE("test metal full") { +TEST_CASE("test gpu full") { for (auto t : types) { auto out_cpu = full({4, 4}, 2, t, Device::cpu); auto out_gpu = full({4, 4}, 2, t, Device::gpu); @@ -63,7 +54,7 @@ TEST_CASE("test metal full") { } } -TEST_CASE("test metal astype") { +TEST_CASE("test gpu astype") { array x = array({-4, -3, -2, -1, 0, 1, 2, 3}); // Check all types work for (auto t : types) { @@ -80,7 +71,7 @@ TEST_CASE("test metal astype") { } } -TEST_CASE("test metal reshape") { +TEST_CASE("test gpu reshape") { array x = array({0, 1, 2, 3, 4, 5, 6, 7}); auto out_cpu = reshape(x, {2, 2, 2}); auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu); @@ -96,7 +87,7 @@ TEST_CASE("test metal reshape") { CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); } -TEST_CASE("test metal reduce") { +TEST_CASE("test gpu reduce") { { array a(true); CHECK_EQ(all(a, Device::gpu).item(), true); @@ -190,7 +181,7 @@ TEST_CASE("test metal reduce") { } } -TEST_CASE("test metal binary ops") { +TEST_CASE("test gpu binary ops") { // scalar-scalar { array a(2.0f); @@ -338,7 +329,7 @@ TEST_CASE("test metal binary ops") { } } -TEST_CASE("test metal unary ops") { +TEST_CASE("test gpu unary ops") { // contiguous { array x({-1.0f, 0.0f, 1.0f}); @@ -392,7 +383,7 @@ TEST_CASE("test metal unary ops") { } } -TEST_CASE("test metal random") { +TEST_CASE("test gpu random") { { auto key = random::key(0); auto x = random::bits({}, 4, key, Device::gpu); @@ -415,7 +406,7 @@ TEST_CASE("test metal random") { } } -TEST_CASE("test metal matmul") { +TEST_CASE("test gpu matmul") { { 
auto a = ones({2, 2}); auto b = ones({2, 2}); @@ -440,7 +431,7 @@ TEST_CASE("test metal matmul") { } } -TEST_CASE("test metal validation") { +TEST_CASE("test gpu validation") { // Run this test with Metal validation enabled // METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \ // -tc="test metal validation" \ diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index de0f3352c..969bc2ba7 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") { x = array({true, true, true, false, true, false}, {2, 3}); CHECK(array_equal(min(x, 1), array({true, false})).item()); CHECK(array_equal(min(x, 0), array({false, true, false})).item()); + + x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); + CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item()); + CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item()); } // Test logsumexp @@ -1036,6 +1040,9 @@ TEST_CASE("test reduction ops") { x = array({-inf, -inf}); CHECK_EQ(logsumexp(x).item(), -inf); + x = repeat(array(-inf), 5000); + CHECK_EQ(logsumexp(x).item(), -inf); + x = array({0.0f, -inf}); CHECK_EQ(logsumexp(x).item(), 0.0f); @@ -1343,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") { x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); CHECK(allclose(exp(x), expected).item()); + + // Complex of -inf + constexpr float inf = std::numeric_limits::infinity(); + x = array(complex64_t{-inf, -inf}); + CHECK_EQ(exp(x).item(), complex64_t{0, 0}); } // Test expm1 @@ -1823,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") { x = array(-inf); y = array(inf); CHECK_EQ(logaddexp(x, y).item(), inf); + + x = array(complex64_t{1, 1}); + y = array(complex64_t{-inf, -inf}); + CHECK_EQ(logaddexp(x, y).item(), complex64_t{1, 1}); } TEST_CASE("test broadcast") { @@ -3859,6 +3875,9 @@ TEST_CASE("test roll") { y = roll(x, {1, 2}, {0, 1}); CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5})) .item()); + + y = roll(array({}), 0, 0); + CHECK(array_equal(y, array({})).item()); } TEST_CASE("test contiguous") { @@ -3911,4 +3930,70 @@ TEST_CASE("test bitwise shift operations") { CHECK_EQ(right_shift_bool_result.dtype(), uint8); CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item()); -} \ No newline at end of file +} + +TEST_CASE("test conv_transpose1d with output_padding") { + auto in = array({1.0, 2.0, 3.0}, {1, 1, 3}); + auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3}); + int stride = 2; + int padding = 0; + int dilation = 1; + int output_padding = 1; + int groups = 1; + + auto out = conv_transpose1d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array({6.0, 0.0}, {1, 2, 1}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test conv_transpose2d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2}); + auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2}); + std::pair stride{2, 2}; + std::pair padding{0, 0}; + std::pair output_padding{1, 1}; + std::pair dilation{1, 1}; + int groups = 1; + + auto out = conv_transpose2d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array( + {3.0, + 3.0, + 0.0, + 0.0, + 7.0, + 7.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0}, + {1, 2, 4, 2}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test conv_transpose3d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 
2, 2}); + auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2}); + std::tuple stride{2, 2, 2}; + std::tuple padding{0, 0, 0}; + std::tuple output_padding{1, 1, 1}; + std::tuple dilation{1, 1, 1}; + int groups = 1; + + auto out = conv_transpose3d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array( + {3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {1, 2, 4, 4, 1}); + CHECK(array_equal(out, expected).item()); +} diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 49f1f300b..6ddd37104 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -350,7 +350,7 @@ TEST_CASE("test random uniform") { // Check float16 { auto key = random::key(0); - auto out = random::uniform({100}, float16, key); + auto out = random::uniform({1000}, float16, key); CHECK_EQ(out.dtype(), float16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item()); @@ -360,7 +360,7 @@ TEST_CASE("test random uniform") { { auto key = random::key(0); - auto out = random::uniform({100}, bfloat16, key); + auto out = random::uniform({1000}, bfloat16, key); CHECK_EQ(out.dtype(), bfloat16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item());
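
The staged release flow documented in the new setup.py comments (MLX_BUILD_STAGE=1 for the pure "mlx" wheel, MLX_BUILD_STAGE=2 for the backend wheel such as mlx-metal, mlx-cuda, or mlx-cpu) can be exercised locally with a small driver. The sketch below is illustrative only and not part of this diff: the build_stage helper is hypothetical, and it assumes the PyPA build package is installed and that each stage starts from "python setup.py clean --all", as the setup.py comments require.

# Hypothetical driver for the two-stage release build described in the
# setup.py comments above (not part of this diff).
import os
import subprocess
import sys


def build_stage(stage: int) -> None:
    # Each stage must start from a clean build tree.
    subprocess.run([sys.executable, "setup.py", "clean", "--all"], check=True)
    env = dict(os.environ, MLX_BUILD_STAGE=str(stage))
    # Assumes the PyPA `build` package is installed; the wheel lands in dist/.
    subprocess.run([sys.executable, "-m", "build", "-w"], env=env, check=True)


if __name__ == "__main__":
    build_stage(1)  # "mlx" wheel: everything except backend binaries (libmlx, metallib, ...)
    build_stage(2)  # backend wheel (mlx-metal / mlx-cuda / mlx-cpu)

Per the setup.py comments, the stage-2 wheel carries only platform tags, so it needs to be built once per platform rather than per Python version.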