diff --git a/.circleci/config.yml b/.circleci/config.yml index 25fb71fb5..250f35faa 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,8 @@ version: 2.1 +orbs: + apple: ml-explore/pr-approval@0.1.0 + parameters: nightly_build: type: boolean @@ -7,6 +10,9 @@ parameters: weekly_build: type: boolean default: false + test_release: + type: boolean + default: false jobs: linux_build_and_test: @@ -57,17 +63,18 @@ jobs: command: ./build/tests/tests mac_build_and_test: - machine: true - resource_class: ml-explore/m-builder + macos: + xcode: "15.2.0" + resource_class: macos.m1.large.gen1 steps: - checkout - run: name: Install dependencies command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=3.9 - conda activate runner-env + brew install python@3.9 + python3.9 -m venv env + source env/bin/activate + pip install --upgrade pip pip install --upgrade cmake pip install --upgrade pybind11[global] pip install pybind11-stubgen @@ -78,203 +85,158 @@ jobs: - run: name: Install Python package command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace - CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop + source env/bin/activate + CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v - run: name: Generate package stubs command: | - eval "$(conda shell.bash hook)" - conda activate runner-env + source env/bin/activate python setup.py generate_stubs - run: name: Run Python tests command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu - DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu + source env/bin/activate + LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu + LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu # TODO: Reenable when extension api becomes stable # - run: # name: Build example extension # command: | - # eval "$(conda shell.bash hook)" - # conda activate runner-env - # cd examples/extensions && python -m pip install . + # cd examples/extensions && python3.11 -m pip install . - store_test_results: path: test-results - run: name: Build CPP only command: | + source env/bin/activate mkdir -p build && cd build && cmake .. && make -j - run: name: Run CPP tests - command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests + command: | + DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests + DEVICE=cpu ./build/tests/tests build_release: - machine: true - resource_class: ml-explore/m-builder parameters: python_version: type: string default: "3.9" - macos_version: + xcode_version: type: string - default: "14" + default: "15.2.0" + build_env: + type: string + default: "" + macos: + xcode: << parameters.xcode_version >> + resource_class: macos.m1.large.gen1 steps: - checkout - run: name: Install dependencies command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=<< parameters.python_version >> - conda activate runner-env + brew install python@<< parameters.python_version >> + python<< parameters.python_version >> -m venv env + source env/bin/activate + pip install --upgrade pip pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install --upgrade setuptools pip install pybind11-stubgen pip install numpy pip install twine - # TODO: Update build system to switch away from setup.py develop + pip install build - run: name: Install Python package command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - PYPI_RELEASE=1 \ + source env/bin/activate + DEV_RELEASE=1 \ CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py develop + pip install . -v - run: name: Generate package stubs command: | - eval "$(conda shell.bash hook)" - conda activate runner-env + source env/bin/activate python setup.py generate_stubs - run: - name: Publish Python package + name: Build Python package command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - PYPI_RELEASE=1 \ + source env/bin/activate + << parameters.build_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py bdist_wheel - twine upload dist/* --repository mlx + python -m build -w + - when: + condition: << parameters.build_env >> + steps: + - run: + name: Upload package + command: | + source env/bin/activate + twine upload dist/* - store_artifacts: path: dist/ - build_dev_release: - machine: true - resource_class: ml-explore/m-builder + build_linux_test_release: parameters: python_version: type: string default: "3.9" - macos_version: + extra_env: type: string - default: "14" + default: "DEV_RELEASE=1" + docker: + - image: ubuntu:20.04 steps: - checkout - run: - name: Install dependencies + name: Build wheel command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=<< parameters.python_version >> - conda activate runner-env + PYTHON=python<< parameters.python_version >> + apt-get update + apt-get upgrade -y + DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata + apt-get install -y apt-utils + apt-get install -y software-properties-common + add-apt-repository -y ppa:deadsnakes/ppa + apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full + apt-get install -y libblas-dev liblapack-dev liblapacke-dev + apt-get install -y build-essential git + $PYTHON -m venv env + source env/bin/activate + pip install --upgrade pip pip install --upgrade cmake pip install --upgrade pybind11[global] + pip install --upgrade setuptools pip install pybind11-stubgen pip install numpy - pip install twine - - run: - name: Install Python package - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - DEV_RELEASE=1 \ + pip install auditwheel + pip install patchelf + pip install build + << parameters.extra_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py develop - - run: - name: Generate package stubs - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env + pip install . -v python setup.py generate_stubs - - run: - name: Publish Python package - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - DEV_RELEASE=1 \ + << parameters.extra_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py bdist_wheel - twine upload dist/* --repository mlx + python -m build --wheel + auditwheel show dist/* + auditwheel repair dist/* --plat manylinux_2_31_x86_64 - store_artifacts: - path: dist/ - - build_package: - machine: true - resource_class: ml-explore/m-builder - parameters: - python_version: - type: string - default: "3.9" - macos_version: - type: string - default: "14" - steps: - - checkout - - run: - name: Install dependencies - command: | - eval "$(conda shell.bash hook)" - rm -r $CONDA_PREFIX/envs/runner-env - conda create -y -n runner-env python=<< parameters.python_version >> - conda activate runner-env - pip install --upgrade cmake - pip install --upgrade pybind11[global] - pip install pybind11-stubgen - pip install numpy - pip install twine - - run: - name: Install Python package - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py develop - - run: - name: Generate package stubs - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - python setup.py generate_stubs - - run: - name: Build package distribution - command: | - eval "$(conda shell.bash hook)" - conda activate runner-env - DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \ - CMAKE_BUILD_PARALLEL_LEVEL="" \ - python setup.py bdist_wheel - - store_artifacts: - path: dist/ + path: wheelhouse/ workflows: build_and_test: when: and: + - matches: + pattern: "^(?!pull/)[-\\w]+$" + value: << pipeline.git.branch >> - not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.weekly_build >> + - not: << pipeline.parameters.test_release >> jobs: - - linux_build_and_test - mac_build_and_test + - linux_build_and_test - build_release: filters: tags: @@ -284,20 +246,53 @@ workflows: matrix: parameters: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - macos_version: ["13", "14"] + xcode_version: ["14.3.1", "15.2.0"] + build_env: ["PYPI_RELEASE=1"] + prb: + when: + matches: + pattern: "^pull/\\d+(/head)?$" + value: << pipeline.git.branch >> + jobs: + - hold: + type: approval + - apple/authenticate: + context: pr-approval + - mac_build_and_test: + requires: [ hold ] + - linux_build_and_test: + requires: [ hold ] nightly_build: - when: << pipeline.parameters.nightly_build >> + when: + and: + - equal: [ main, << pipeline.git.branch >> ] + - << pipeline.parameters.nightly_build >> jobs: - - build_package: + - build_release: matrix: parameters: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - macos_version: ["13", "14"] + xcode_version: ["14.3.1", "15.2.0"] weekly_build: - when: << pipeline.parameters.weekly_build >> + when: + and: + - equal: [ main, << pipeline.git.branch >> ] + - << pipeline.parameters.weekly_build >> jobs: - - build_dev_release: + - build_release: matrix: parameters: python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - macos_version: ["13", "14"] + xcode_version: ["14.3.1", "15.2.0"] + build_env: ["DEV_RELEASE=1"] + linux_test_release: + when: + and: + - equal: [ main, << pipeline.git.branch >> ] + - << pipeline.parameters.test_release >> + jobs: + - build_linux_test_release: + matrix: + parameters: + python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + extra_env: ["PYPI_RELEASE=1"] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d0ebc6d48..dd5ebec30 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: clang-format # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: https://github.com/psf/black-pre-commit-mirror - rev: 23.12.1 + rev: 24.2.0 hooks: - id: black - repo: https://github.com/pycqa/isort diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 12fdd47bc..e15aafd5b 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -10,8 +10,9 @@ MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. -- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support -- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. +- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. +- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. +- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Luca Arnaboldi: Added `Ceil` and `Floor` ops. Implemented pickling, copy and deepcopy for Python arrays. diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d6bef9d2..8b1a4cf52 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) - set(MLX_VERSION 0.1.0) + set(MLX_VERSION 0.3.0) endif() # --------------------- Processor tests ------------------------- @@ -123,8 +123,8 @@ else() /usr/include /usr/local/include $ENV{BLAS_HOME}/include) - message(STATUS "Blas lib" ${BLAS_LIBRARIES}) - message(STATUS "Blas include" ${BLAS_INCLUDE_DIRS}) + message(STATUS "Blas lib " ${BLAS_LIBRARIES}) + message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_link_libraries(mlx ${BLAS_LIBRARIES}) find_package(LAPACK REQUIRED) @@ -134,7 +134,7 @@ else() find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include) - message(STATUS "Lapack lib" ${LAPACK_LIBRARIES}) + message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) target_link_libraries(mlx ${LAPACK_LIBRARIES}) diff --git a/MANIFEST.in b/MANIFEST.in index d81234106..9faafee45 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include CMakeLists.txt recursive-include mlx/ * include python/src/* -python/mlx/py.typed # support type hinting as in PEP-561 +include python/mlx/py.typed # support type hinting as in PEP-561 diff --git a/README.md b/README.md index a9abd58ad..118cc828e 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx) -MLX is an array framework for machine learning on Apple silicon, brought to you -by Apple machine learning research. +MLX is an array framework for machine learning research on Apple silicon, +brought to you by Apple machine learning research. Some key features of MLX include: diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index a9d3df22d..5b71cf583 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -80,10 +80,8 @@ if __name__ == "__main__": _filter = make_predicate(args.filter, args.negative_filter) if args.mlx_dtypes: - compare_filtered = ( - lambda x: compare_mlx_dtypes( - x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1] - ) + compare_filtered = lambda x: ( + compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]) if _filter(x) else None ) diff --git a/benchmarks/python/compile_bench.py b/benchmarks/python/compile_bench.py new file mode 100644 index 000000000..0d5d9f61d --- /dev/null +++ b/benchmarks/python/compile_bench.py @@ -0,0 +1,109 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse +import math +import random + +import mlx.core as mx +from time_utils import time_fn + + +def bench_gelu(): + + def gelu(x): + return x * (1 + mx.erf(x / math.sqrt(2))) / 2 + + x = mx.random.uniform(shape=(1000, 1024)) + + def gen_fun(fun): + def bench_fun(x): + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + time_fn(gen_fun(gelu), x, msg="fixed gelu") + time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu") + + def randint(): + return random.randint(1, x.shape[0]) + + def gen_fun(fun): + def bench_fun(x, y): + x = x[: randint()] + for _ in range(10): + x = fun(x) + y = fun(y) + return x, y + + return bench_fun + + y = mx.random.uniform(shape=(1000, 1024)) + time_fn(gen_fun(gelu), x, y, msg="variable gelu") + time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu") + time_fn( + gen_fun(mx.compile(gelu, shapeless=True)), + x, + y, + msg="shapeless variable gelu", + ) + + +def bench_layernorm(): + + weight = mx.random.uniform(shape=(4096,)).astype(mx.float16) + bias = mx.random.uniform(shape=(4096,)).astype(mx.float16) + mx.eval(weight, bias) + + def layernorm(x): + x = x.astype(mx.float32) + means = mx.mean(x, axis=-1, keepdims=True) + var = mx.var(x, axis=-1, keepdims=True) + x = (x - means) * mx.rsqrt(var + 1e-4) + x = x.astype(mx.float16) + return weight * x + bias + + x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16) + + def gen_fun(fun): + def bench_fun(x): + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + time_fn(gen_fun(layernorm), x, msg="fixed layernorm") + time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm") + + def randint(): + return random.randint(1, x.shape[0]) + + def gen_fun(fun): + def bench_fun(x): + x = x[: randint()] + for _ in range(10): + x = fun(x) + return x + + return bench_fun + + random.seed(0) + time_fn(gen_fun(layernorm), x, msg="variable layernorm") + random.seed(0) + time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm") + random.seed(0) + time_fn( + gen_fun(mx.compile(layernorm, shapeless=True)), + x, + msg="shapeless variable layernorm", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Compile benchmarks.") + args = parser.parse_args() + + bench_gelu() + bench_layernorm() diff --git a/benchmarks/python/gather_bench.py b/benchmarks/python/gather_bench.py index 8ce6dcb8f..e000841d2 100644 --- a/benchmarks/python/gather_bench.py +++ b/benchmarks/python/gather_bench.py @@ -5,18 +5,7 @@ from time import time import mlx.core as mx import torch - - -def measure_runtime(fn, **kwargs): - # Warmup - for _ in range(5): - fn(**kwargs) - - tic = time() - iters = 10 - for _ in range(iters): - fn(**kwargs) - return (time() - tic) * 1000 / iters +from time_utils import measure_runtime def benchmark_gather_mlx(x_shape, idx_shape): diff --git a/benchmarks/python/rope_bench.py b/benchmarks/python/rope_bench.py new file mode 100644 index 000000000..62f01648e --- /dev/null +++ b/benchmarks/python/rope_bench.py @@ -0,0 +1,35 @@ +# Copyright © 2023-2024 Apple Inc. + +import mlx.core as mx +import mlx.nn as nn +from time_utils import time_fn + + +def time_rope(): + rope = nn.RoPE(4096) + + # vec + x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16) + mx.eval(x) + + def rope_vec(x): + for _ in range(32): + x = rope(x) + return x + + time_fn(rope_vec, x) + + # matrix + x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16) + mx.eval(x) + + def rope_mat(x): + for _ in range(32): + x = rope(x) + return x + + time_fn(rope_mat, x) + + +if __name__ == "__main__": + time_rope() diff --git a/benchmarks/python/scatter_bench.py b/benchmarks/python/scatter_bench.py new file mode 100644 index 000000000..2d63d8bf1 --- /dev/null +++ b/benchmarks/python/scatter_bench.py @@ -0,0 +1,56 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse + +import mlx.core as mx +import torch +from time_utils import measure_runtime + + +def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape): + def scatter(dst, x, idx): + dst[idx] = x + mx.eval(dst) + + idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape) + x = mx.random.normal(x_shape).astype(mx.float32) + dst = mx.random.normal(dst_shape).astype(mx.float32) + + runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx) + print(f"MLX: {runtime:.3f}ms") + + +def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device): + def gather(dst, x, idx, device): + dst[idx] = x + if device == torch.device("mps"): + torch.mps.synchronize() + + idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device) + x = torch.randn(x_shape, dtype=torch.float32).to(device) + dst = torch.randn(dst_shape, dtype=torch.float32).to(device) + + runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device) + print(f"PyTorch: {runtime:.3f}ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Gather benchmarks.") + parser.add_argument("--cpu", action="store_true", help="Use the CPU.") + args = parser.parse_args() + + if args.cpu: + mx.set_default_device(mx.cpu) + device = torch.device("cpu") + else: + device = torch.device("mps") + + dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)] + idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)] + x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)] + + for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes): + print("=" * 20) + print(f"X {x_shape}, Indices {idx_shape}") + benchmark_scatter_mlx(dst_shape, x_shape, idx_shape) + benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device) diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py index 73266cb7a..2903c3293 100644 --- a/benchmarks/python/time_utils.py +++ b/benchmarks/python/time_utils.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import time @@ -6,7 +6,11 @@ import mlx.core as mx def time_fn(fn, *args, **kwargs): - print(f"Timing {fn.__name__} ...", end=" ") + msg = kwargs.pop("msg", None) + if msg: + print(f"Timing {msg} ...", end=" ") + else: + print(f"Timing {fn.__name__} ...", end=" ") # warmup for _ in range(5): @@ -20,3 +24,15 @@ def time_fn(fn, *args, **kwargs): msec = 1e3 * (toc - tic) / num_iters print(f"{msec:.5f} msec") + + +def measure_runtime(fn, **kwargs): + # Warmup + for _ in range(5): + fn(**kwargs) + + tic = time.time() + iters = 100 + for _ in range(iters): + fn(**kwargs) + return (time.time() - tic) * 1000 / iters diff --git a/docs/.gitignore b/docs/.gitignore index 5c2693cb6..fa80a135e 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,3 @@ src/python/_autosummary*/ src/python/nn/_autosummary*/ +src/python/optimizers/_autosummary*/ diff --git a/docs/src/_templates/nn-module-template.rst b/docs/src/_templates/nn-module-template.rst deleted file mode 100644 index 49f018eb5..000000000 --- a/docs/src/_templates/nn-module-template.rst +++ /dev/null @@ -1,19 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - - {#{% block methods %} - - {% if methods %} - .. rubric:: {{ _('Methods') }} - - .. autosummary:: - {% for item in methods %} - {%- if item not in inherited_members and item != '__init__' %} - ~{{ name }}.{{ item }} - {%- endif %} - {%- endfor %} - {% endif %} - {% endblock %}#} diff --git a/docs/src/conf.py b/docs/src/conf.py index bec2c976c..0654cf53c 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -26,6 +26,7 @@ extensions = [ python_use_unqualified_type_names = True autosummary_generate = True +autosummary_filename_map = {"mlx.core.Stream": "stream_class"} intersphinx_mapping = { "https://docs.python.org/3": None, diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 3563305bf..5f63c6337 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the You would really like the part of your applications that does this operation on the CPU to be very fast - so you decide that you want it to rely on the ``axpby`` routine provided by the Accelerate_ framework. Continuing to impose -our assumptions on to you, let's also assume that you want to learn how add +our assumptions on to you, let's also assume that you want to learn how to add your own implementation for the gradients of your new operation while going over the ins-and-outs of the MLX framework. diff --git a/docs/src/index.rst b/docs/src/index.rst index 4f4411758..50dfe9083 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -41,6 +41,7 @@ are the CPU and GPU. usage/indexing usage/saving_and_loading usage/function_transforms + usage/compile usage/numpy usage/using_streams diff --git a/docs/src/python/devices_and_streams.rst b/docs/src/python/devices_and_streams.rst index bb9dfae2f..e16ab9875 100644 --- a/docs/src/python/devices_and_streams.rst +++ b/docs/src/python/devices_and_streams.rst @@ -9,9 +9,10 @@ Devices and Streams :toctree: _autosummary Device + Stream default_device set_default_device - Stream default_stream new_stream set_default_stream + stream diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index ef099ef2f..0f5fca9db 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -10,6 +10,8 @@ Layers :template: nn-module-template.rst ALiBi + AvgPool1d + AvgPool2d BatchNorm Conv1d Conv2d @@ -22,6 +24,8 @@ Layers InstanceNorm LayerNorm Linear + MaxPool1d + MaxPool2d Mish MultiHeadAttention PReLU diff --git a/docs/src/python/nn/module.rst b/docs/src/python/nn/module.rst index 042a88028..c3a4dfa62 100644 --- a/docs/src/python/nn/module.rst +++ b/docs/src/python/nn/module.rst @@ -11,6 +11,7 @@ Module :toctree: _autosummary Module.training + Module.state .. rubric:: Methods diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 09e2d5f71..7ec7defc9 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -25,6 +25,9 @@ Operations argpartition argsort array_equal + atleast_1d + atleast_2d + atleast_3d broadcast_to ceil clip diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst index fe8632a7e..f437ddc15 100644 --- a/docs/src/python/optimizers.rst +++ b/docs/src/python/optimizers.rst @@ -29,20 +29,8 @@ model's parameters and the **optimizer state**. # Compute the new parameters but also the optimizer state. mx.eval(model.parameters(), optimizer.state) -.. currentmodule:: mlx.optimizers +.. toctree:: -.. autosummary:: - :toctree: _autosummary - :template: optimizers-template.rst - - OptimizerState - Optimizer - SGD - RMSprop - Adagrad - Adafactor - AdaDelta - Adam - AdamW - Adamax - Lion + optimizers/optimizer + optimizers/common_optimizers + optimizers/schedulers diff --git a/docs/src/python/optimizers/common_optimizers.rst b/docs/src/python/optimizers/common_optimizers.rst new file mode 100644 index 000000000..41b3fba03 --- /dev/null +++ b/docs/src/python/optimizers/common_optimizers.rst @@ -0,0 +1,20 @@ +.. _common_optimizers: + +Common Optimizers +================= + +.. currentmodule:: mlx.optimizers + +.. autosummary:: + :toctree: _autosummary + :template: optimizers-template.rst + + SGD + RMSprop + Adagrad + Adafactor + AdaDelta + Adam + AdamW + Adamax + Lion diff --git a/docs/src/python/optimizers/optimizer.rst b/docs/src/python/optimizers/optimizer.rst new file mode 100644 index 000000000..cf6034dee --- /dev/null +++ b/docs/src/python/optimizers/optimizer.rst @@ -0,0 +1,23 @@ +Optimizer +========= + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Optimizer + + + .. rubric:: Attributes + + .. autosummary:: + :toctree: _autosummary + + Optimizer.state + + .. rubric:: Methods + + .. autosummary:: + :toctree: _autosummary + + Optimizer.apply_gradients + Optimizer.init + Optimizer.update diff --git a/docs/src/python/optimizers/schedulers.rst b/docs/src/python/optimizers/schedulers.rst new file mode 100644 index 000000000..a83883ddb --- /dev/null +++ b/docs/src/python/optimizers/schedulers.rst @@ -0,0 +1,13 @@ +.. _schedulers: + +Schedulers +========== + +.. currentmodule:: mlx.optimizers + +.. autosummary:: + :toctree: _autosummary + + step_decay + exponential_decay + cosine_decay diff --git a/docs/src/python/transforms.rst b/docs/src/python/transforms.rst index cc8d681d5..ad9ba579b 100644 --- a/docs/src/python/transforms.rst +++ b/docs/src/python/transforms.rst @@ -9,6 +9,9 @@ Transforms :toctree: _autosummary eval + compile + disable_compile + enable_compile grad value_and_grad jvp diff --git a/docs/src/usage/compile.rst b/docs/src/usage/compile.rst new file mode 100644 index 000000000..97d5503a3 --- /dev/null +++ b/docs/src/usage/compile.rst @@ -0,0 +1,430 @@ +.. _compile: + +Compilation +=========== + +.. currentmodule:: mlx.core + +MLX has a :func:`compile` function transformation which compiles computation +graphs. Function compilation results in smaller graphs by merging common work +and fusing certain operations. In many cases this can lead to big improvements +in run-time and memory use. + +Getting started with :func:`compile` is simple, but there are some edge cases +that are good to be aware of for more complex graphs and advanced usage. + +Basics of Compile +----------------- + +Let's start with a simple example: + +.. code-block:: python + + def fun(x, y): + return mx.exp(-x) + y + + x = mx.array(1.0) + y = mx.array(2.0) + + # Regular call, no compilation + # Prints: array(2.36788, dtype=float32) + print(fun(x, y)) + + # Compile the function + compiled_fun = mx.compile(fun) + + # Prints: array(2.36788, dtype=float32) + print(compiled_fun(x, y)) + +The output of both the regular function and the compiled function is the same +up to numerical precision. + +The first time you call a compiled function, MLX will build the compute +graph, optimize it, and generate and compile code. This can be relatively +slow. However, MLX will cache compiled functions, so calling a compiled +function multiple times will not initiate a new compilation. This means you +should typically compile functions that you plan to use more than once. + +.. code-block:: python + + def fun(x, y): + return mx.exp(-x) + y + + x = mx.array(1.0) + y = mx.array(2.0) + + compiled_fun = mx.compile(fun) + + # Compiled here + compiled_fun(x, y) + + # Not compiled again + compiled_fun(x, y) + + # Not compiled again + mx.compile(fun)(x, y) + +There are some important cases to be aware of that can cause a function to +be recompiled: + +* Changing the shape or number of dimensions +* Changing the type of any of the inputs +* Changing the number of inputs to the function + +In certain cases only some of the compilation stack will be rerun (for +example when changing the shapes) and in other cases the full compilation +stack will be rerun (for example when changing the types). In general you +should avoid compiling functions too frequently. + +Another idiom to watch out for is compiling functions which get created and +destroyed frequently. This can happen, for example, when compiling an anonymous +function in a loop: + +.. code-block:: python + + a = mx.array(1.0) + # Don't do this, compiles lambda at each iteration + for _ in range(5): + mx.compile(lambda x: mx.exp(mx.abs(x)))(a) + +Example Speedup +--------------- + +The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with +Transformer-based models. The implementation involves several unary and binary +element-wise operations: + +.. code-block:: python + + def gelu(x): + return x * (1 + mx.erf(x / math.sqrt(2))) / 2 + +If you use this function with small arrays, it will be overhead bound. If you +use it with large arrays it will be memory bandwidth bound. However, all of +the operations in the ``gelu`` are fusible into a single kernel with +:func:`compile`. This can speedup both cases considerably. + +Let's compare the runtime of the regular function versus the compiled +function. We'll use the following timing helper which does a warm up and +handles synchronization: + +.. code-block:: python + + import time + + def timeit(fun, x): + # warm up + for _ in range(10): + mx.eval(fun(x)) + + tic = time.perf_counter() + for _ in range(100): + mx.eval(fun(x)) + toc = time.perf_counter() + tpi = 1e3 * (toc - tic) / 100 + print(f"Time per iteration {tpi:.3f} (ms)") + + +Now make an array, and benchmark both functions: + +.. code-block:: python + + x = mx.random.uniform(shape=(32, 1000, 4096)) + timeit(nn.gelu, x) + timeit(mx.compile(nn.gelu), x) + +On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is +five times faster. + +.. note:: + + As of the latest MLX, CPU functions are not fully compiled. Compiling CPU + functions can still be helpful, but won't typically result in as large a + speedup as compiling operations that run on the GPU. + + +Debugging +--------- + +When a compiled function is first called, it is traced with placeholder +inputs. This means you can't evaluate arrays (for example to print their +contents) inside compiled functions. + +.. code-block:: python + + @mx.compile + def fun(x): + z = -x + print(z) # Crash + return mx.exp(z) + + fun(mx.array(5.0)) + +For debugging, inspecting arrays can be helpful. One way to do that is to +globally disable compilation using the :func:`disable_compile` function or +``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though +``fun`` is compiled: + +.. code-block:: python + + @mx.compile + def fun(x): + z = -x + print(z) # Okay + return mx.exp(z) + + mx.disable_compile() + fun(mx.array(5.0)) + + +Pure Functions +-------------- + +Compiled functions are intended to be *pure*; that is they should not have side +effects. For example: + +.. code-block:: python + + state = [] + + @mx.compile + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z) + + fun(mx.array(1.0), mx.array(2.0)) + # Crash! + print(state) + +After the first call of ``fun``, the ``state`` list will hold a placeholder +array. The placeholder does not have any data; it is only used to build the +computation graph. Printing such an array results in a crash. + +You have two options to deal with this. The first option is to simply return +``state`` as an output: + +.. code-block:: python + + state = [] + + @mx.compile + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z), state + + _, state = fun(mx.array(1.0), mx.array(2.0)) + # Prints [array(3, dtype=float32)] + print(state) + +In some cases returning updated state can be pretty inconvenient. Hence, +:func:`compile` has a parameter to capture implicit outputs: + +.. code-block:: python + + from functools import partial + + state = [] + + # Tell compile to capture state as an output + @partial(mx.compile, outputs=state) + def fun(x, y): + z = x + y + state.append(z) + return mx.exp(z), state + + fun(mx.array(1.0), mx.array(2.0)) + # Prints [array(3, dtype=float32)] + print(state) + +This is particularly useful for compiling a function which includes an update +to a container of arrays, as is commonly done when training the parameters of a +:class:`mlx.nn.Module`. + +Compiled functions will also treat any inputs not in the parameter list as +constants. For example: + +.. code-block:: python + + state = [mx.array(1.0)] + + @mx.compile + def fun(x): + return x + state[0] + + # Prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + + # Update state + state[0] = mx.array(5.0) + + # Still prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + +In order to have the change of state reflected in the outputs of ``fun`` you +again have two options. The first option is to simply pass ``state`` as input +to the function. In some cases this can be pretty inconvenient. Hence, +:func:`compile` also has a parameter to capture implicit inputs: + +.. code-block:: python + + from functools import partial + state = [mx.array(1.0)] + + # Tell compile to capture state as an input + @partial(mx.compile, inputs=state) + def fun(x): + return x + state[0] + + # Prints array(2, dtype=float32) + print(fun(mx.array(1.0))) + + # Update state + state[0] = mx.array(5.0) + + # Prints array(6, dtype=float32) + print(fun(mx.array(1.0))) + + +Compiling Training Graphs +------------------------- + +This section will step through how to use :func:`compile` with a simple example +of a common setup: training a model with :obj:`mlx.nn.Module` using an +:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the +full forward, backward, and update with :func:`compile`. + +To start, here is the simple example without any compilation: + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + import mlx.optimizers as optim + + # 4 examples with 10 features each + x = mx.random.uniform(shape=(4, 10)) + + # 0, 1 targets + y = mx.array([0, 1, 0, 1]) + + # Simple linear model + model = nn.Linear(10, 1) + + # SGD with momentum + optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) + + def loss_fn(model, x, y): + logits = model(x).squeeze() + return nn.losses.binary_cross_entropy(logits, y) + + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + + # Perform 10 steps of gradient descent + for it in range(10): + loss, grads = loss_and_grad_fn(model, x, y) + optimizer.update(model, grads) + mx.eval(model.parameters(), optimizer.state) + +To compile the update we can put it all in a function and compile it with the +appropriate input and output captures. Here's the same example but compiled: + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + import mlx.optimizers as optim + from functools import partial + + # 4 examples with 10 features each + x = mx.random.uniform(shape=(4, 10)) + + # 0, 1 targets + y = mx.array([0, 1, 0, 1]) + + # Simple linear model + model = nn.Linear(10, 1) + + # SGD with momentum + optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) + + def loss_fn(model, x, y): + logits = model(x).squeeze() + return nn.losses.binary_cross_entropy(logits, y) + + # The state that will be captured as input and output + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x, y): + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, x, y) + optimizer.update(model, grads) + return loss + + # Perform 10 steps of gradient descent + for it in range(10): + loss = step(x, y) + # Evaluate the model and optimizer state + mx.eval(state) + print(loss) + + +.. note:: + + If you are using a module which performs random sampling such as + :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the + ``state`` captured by :func:`compile`, i.e. ``state = [model.state, + optimizer.state, mx.random.state]``. + + +.. note:: + + For more examples of compiling full training graphs checkout the `MLX + Examples `_ GitHub repo. + +Transformations with Compile +---------------------------- + +In MLX function transformations are composable. You can apply any function +transformation to the output of any other function transformation. For more on +this, see the documentation on :ref:`function transforms +`. + +Compiling transformed functions works just as expected: + +.. code-block:: python + + grad_fn = mx.grad(mx.exp) + + compiled_grad_fn = mx.compile(grad_fn) + + # Prints: array(2.71828, dtype=float32) + print(grad_fn(mx.array(1.0))) + + # Also prints: array(2.71828, dtype=float32) + print(compiled_grad_fn(mx.array(1.0))) + +.. note:: + + In order to compile as much as possible, a transformation of a compiled + function will not by default be compiled. To compile the transformed + function simply pass it through :func:`compile`. + +You can also compile functions which themselves call compiled functions. A +good practice is to compile the outer most function to give :func:`compile` +the most opportunity to optimize the computation graph: + +.. code-block:: python + + @mx.compile + def inner(x): + return mx.exp(-mx.abs(x)) + + def outer(x): + inner(inner(x)) + + # Compiling the outer function is good to do as it will likely + # be faster even though the inner functions are compiled + fun = mx.compile(outer) diff --git a/docs/src/usage/function_transforms.rst b/docs/src/usage/function_transforms.rst index 72a313f97..02c5dec48 100644 --- a/docs/src/usage/function_transforms.rst +++ b/docs/src/usage/function_transforms.rst @@ -5,9 +5,12 @@ Function Transforms .. currentmodule:: mlx.core -MLX uses composable function transformations for automatic differentiation and -vectorization. The key idea behind composable function transformations is that -every transformation returns a function which can be further transformed. +MLX uses composable function transformations for automatic differentiation, +vectorization, and compute graph optimizations. To see the complete list of +function transformations check-out the :ref:`API documentation `. + +The key idea behind composable function transformations is that every +transformation returns a function which can be further transformed. Here is a simple example: @@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep getting higher order derivatives. Any of the MLX function transformations can be composed in any order to any -depth. To see the complete list of function transformations check-out the -:ref:`API documentation `. See the following sections for more -information on :ref:`automatic differentiaion ` and -:ref:`automatic vectorization `. +depth. See the following sections for more information on :ref:`automatic +differentiaion ` and :ref:`automatic vectorization `. +For more information on :func:`compile` see the :ref:`compile documentation `. + Automatic Differentiation ------------------------- diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index cb65e518b..d2f021af5 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -3,9 +3,10 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp diff --git a/mlx/array.cpp b/mlx/array.cpp index 7f3dd854b..83c2fe6d7 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -82,6 +82,13 @@ array::array(std::initializer_list data) init(data.begin()); } +array::array(std::initializer_list data, Dtype dtype) + : array_desc_(std::make_shared( + std::vector{static_cast(data.size())}, + dtype)) { + init(data.begin()); +} + /* Build an array from a shared buffer */ array::array( allocator::Buffer data, diff --git a/mlx/array.h b/mlx/array.h index 2b849a7ae..fe01cbfd7 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -41,6 +41,9 @@ class array { /* Special case so empty lists default to float32. */ array(std::initializer_list data); + /* Special case so array({}, type) is an empty array. */ + array(std::initializer_list data, Dtype dtype); + template array( std::initializer_list data, @@ -121,6 +124,9 @@ class array { template T item(); + template + T item() const; + struct ArrayIterator { using iterator_category = std::random_access_iterator_tag; using difference_type = size_t; @@ -454,6 +460,18 @@ T array::item() { return *data(); } +template +T array::item() const { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + if (!is_evaled()) { + throw std::invalid_argument( + "item() const can only be called on evaled arrays"); + } + return *data(); +} + template void array::init(It src) { set_data(allocator::malloc(size() * size_of(dtype()))); diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 6cd851111..e147b5888 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -33,7 +33,6 @@ DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT(Ceil) -DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Copy) DEFAULT_MULTI(CustomVJP) @@ -62,6 +61,7 @@ DEFAULT(Partition) DEFAULT_MULTI(QRF) DEFAULT(RandomBits) DEFAULT(Reshape) +DEFAULT(Remainder) DEFAULT(Round) DEFAULT(Scatter) DEFAULT(Sigmoid) @@ -81,11 +81,8 @@ void Abs::eval_cpu(const std::vector& inputs, array& out) { } else if (in.dtype() == int32 && in.flags().contiguous) { set_unary_output_data(in, out); vDSP_vabsi(in.data(), 1, out.data(), 1, in.data_size()); - } else if (is_unsigned(in.dtype())) { - // No-op for unsigned types - out.copy_shared_buffer(in); } else { - unary(in, out, AbsOp()); + eval(inputs, out); } } @@ -292,45 +289,6 @@ void Divide::eval_cpu(const std::vector& inputs, array& out) { } } -// TODO: Avoid code duplication with the common backend. -struct RemainderFn { - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return std::fmod(numerator, denominator); - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } -}; - -void Remainder::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - - if (a.dtype() == float32) { - binary( - a, - b, - out, - RemainderFn{}, - UseDefaultBinaryOp(), - UseDefaultBinaryOp(), - [](const auto* a, const auto* b, auto* o, auto n) { - int num_el = n; - vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el); - }); - } else { - binary(a, b, out, RemainderFn{}); - } -} - void Exp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; diff --git a/mlx/backend/accelerate/quantized.cpp b/mlx/backend/accelerate/quantized.cpp index cc31c88a3..e9fec1303 100644 --- a/mlx/backend/accelerate/quantized.cpp +++ b/mlx/backend/accelerate/quantized.cpp @@ -24,8 +24,6 @@ void _qmm_t_4_64( constexpr int bitmask = (1 << bits) - 1; constexpr int pack_factor = 32 / bits; constexpr int packs_in_group = group_size / pack_factor; - const int Kg = K / group_size; - const int Kw = K / pack_factor; for (int m = 0; m < M; m++) { const uint32_t* w_local = w; diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp index fcd8fbe50..8b95e32d4 100644 --- a/mlx/backend/accelerate/softmax.cpp +++ b/mlx/backend/accelerate/softmax.cpp @@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous auto check_input = [](array x) { - if (x.strides()[x.ndim() - 1] == 1) { + bool no_copy = x.strides()[x.ndim() - 1] == 1; + if (x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 0263dff9b..38a9819e5 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,3 +1,33 @@ + +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(CLANG TRUE) +endif() + +add_custom_command( + OUTPUT compiled_preamble.cpp + COMMAND /bin/bash + ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_CXX_COMPILER} + ${CMAKE_SOURCE_DIR} + ${CLANG} + + DEPENDS make_compiled_preamble.sh + compiled_preamble.h + ${CMAKE_SOURCE_DIR}/mlx/types/half_types.h + ${CMAKE_SOURCE_DIR}/mlx/types/fp16.h + ${CMAKE_SOURCE_DIR}/mlx/types/bf16.h + ${CMAKE_SOURCE_DIR}/mlx/types/complex.h + ops.h +) + +add_custom_target( + cpu_compiled_preamble + DEPENDS compiled_preamble.cpp +) + +add_dependencies(mlx cpu_compiled_preamble) + target_sources( mlx PRIVATE @@ -11,6 +41,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp @@ -18,4 +49,5 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index a51d22d0f..ec7097797 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -7,6 +7,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary_two.h" +#include "mlx/backend/common/ops.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -73,7 +74,7 @@ void Add::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x + y; }); + binary(a, b, out, detail::Add()); } void DivMod::eval( @@ -135,88 +136,56 @@ void Divide::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x / y; }); + binary(a, b, out, detail::Divide()); } -struct RemainderFn { - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return std::fmod(numerator, denominator); - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } -}; - void Remainder::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, RemainderFn{}); + binary(a, b, out, detail::Remainder()); } void Equal::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (equal_nan_) { - comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) { - return x == y || (std::isnan(x) && std::isnan(y)); - }); + comparison_op(inputs[0], inputs[1], out, detail::NaNEqual()); } else { - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; }); + comparison_op(inputs[0], inputs[1], out, detail::Equal()); } } void Greater::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; }); + comparison_op(inputs[0], inputs[1], out, detail::Greater()); } void GreaterEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; }); + comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual()); } void Less::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; }); + comparison_op(inputs[0], inputs[1], out, detail::Less()); } void LessEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; }); + comparison_op(inputs[0], inputs[1], out, detail::LessEqual()); } void LogAddExp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - auto op = [](auto x, auto y) { - constexpr float inf = std::numeric_limits::infinity(); - auto maxval = (x > y) ? x : y; - auto minval = (x > y) ? y : x; - return (minval == -inf || maxval == inf) - ? maxval - : static_cast( - maxval + std::log1p(std::exp(minval - maxval))); - }; if (is_floating_point(out.dtype())) { if (out.dtype() == float32) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == float16) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == bfloat16) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else { std::ostringstream err; err << "[logaddexp] Does not support " << out.dtype(); @@ -233,84 +202,40 @@ void Maximum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - - if (is_floating_point(out.dtype())) { - binary(a, b, out, [](auto x, auto y) { - if (std::isnan(x)) { - return x; - } - return (x > y) ? x : y; - }); - } else { - binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; }); - } + binary(a, b, out, detail::Maximum()); } void Minimum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - if (is_floating_point(out.dtype())) { - binary(a, b, out, [](auto x, auto y) { - if (std::isnan(x)) { - return x; - } - return (x < y) ? x : y; - }); - } else { - binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; }); - } + binary(a, b, out, detail::Minimum()); } void Multiply::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x * y; }); + binary(a, b, out, detail::Multiply()); } void NotEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; }); + comparison_op(inputs[0], inputs[1], out, detail::NotEqual()); } -struct PowerFn { - template - std::enable_if_t, T> operator()(T base, T exp) { - return std::pow(base, exp); - } - - template - std::enable_if_t, T> operator()(T base, T exp) { - if (exp < 0) { - throw std::invalid_argument( - "Integers cannot be raise to negative powers"); - } - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } -}; - void Power::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, PowerFn{}); + binary(a, b, out, detail::Power()); } void Subtract::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x - y; }); + binary(a, b, out, detail::Subtract()); } } // namespace mlx::core diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 149556530..529ad2fa5 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -1,59 +1,507 @@ // Copyright © 2023-2024 Apple Inc. -#include +#include +#include +#include +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/compiled_preamble.h" +#include "mlx/backend/common/utils.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core { -// Build the real tape -std::pair, std::vector> trace_to_real( - const std::vector& trace_tape, - const std::vector& trace_inputs, - const std::vector& trace_outputs, - const std::vector& inputs) { - std::unordered_map trace_to_real; - for (int i = 0; i < inputs.size(); ++i) { - trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); - } - std::queue tape; - for (auto& a : trace_tape) { - // Find real inputs - std::vector real_inputs; - for (auto& in : a.inputs()) { - real_inputs.push_back(trace_to_real.at(in.id())); - } - tape.push( - array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs))); - trace_to_real.insert({a.id(), tape.back()}); - } - - std::vector outputs; - for (auto& o : trace_outputs) { - outputs.push_back(trace_to_real.at(o.id())); - } - return {tape, outputs}; +std::string get_temp_file(const std::string& name) { + return std::filesystem::temp_directory_path().append(name); } -void Compiled::eval( +std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids) { + std::ostringstream os; + std::ostringstream constant_hasher; + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. + for (auto& a : tape) { + a.primitive().print(os); + } + os << "_"; + + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << (is_scalar(x) ? "S" : "V"); + } + } + os << "_"; + for (auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + return os.str(); +} + +void print_constant(std::ostream& os, const array& x) { + switch (x.dtype()) { + case float32: + return print_float_constant(os, x); + case float16: + return print_float_constant(os, x); + case bfloat16: + return print_float_constant(os, x); + case complex64: + return print_complex_constant(os, x); + case int8: + return print_int_constant(os, x); + case int16: + return print_int_constant(os, x); + case int32: + return print_int_constant(os, x); + case int64: + return print_int_constant(os, x); + case uint8: + return print_int_constant(os, x); + case uint16: + return print_int_constant(os, x); + case uint32: + return print_int_constant(os, x); + case uint64: + return print_int_constant(os, x); + case bool_: + os << std::boolalpha << x.item(); + return; + default: + throw std::runtime_error("Unsupported constant type"); + } +} + +std::string get_type_string(Dtype d) { + switch (d) { + case float32: + return "float"; + case float16: + return "float16_t"; + case bfloat16: + return "bfloat16_t"; + case complex64: + return "complex64_t"; + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + default: { + std::ostringstream msg; + msg << "Unsupported compilation type " << d; + throw std::runtime_error(msg.str()); + } + } +} + +// Return a pointer to a compiled function +void* compile( + const std::string& kernel_name, + const std::string& source_code = "") { + struct DLib { + DLib(const std::string& libname) { + lib = dlopen(libname.c_str(), RTLD_NOW); + if (!lib) { + std::ostringstream msg; + msg << "Could not load C++ shared library " << dlerror(); + throw std::runtime_error(msg.str()); + } + } + + ~DLib() { + dlclose(lib); + } + void* lib; + }; + // Statics to cache compiled libraries and functions + static std::list libs; + static std::unordered_map kernels; + if (auto it = kernels.find(kernel_name); it != kernels.end()) { + return it->second; + } + if (source_code.empty()) { + return nullptr; + } + + std::ostringstream shared_lib_name; + shared_lib_name << "lib" << kernel_name << ".so"; + auto shared_lib_path = get_temp_file(shared_lib_name.str()); + bool lib_exists = false; + { + std::ifstream f(shared_lib_path.c_str()); + lib_exists = f.good(); + } + + if (!lib_exists) { + // Open source file and write source code to it + std::ostringstream source_file_name; + source_file_name << kernel_name << ".cpp"; + auto source_file_path = get_temp_file(source_file_name.str()); + + std::ofstream source_file(source_file_path); + source_file << source_code; + source_file.close(); + + std::ostringstream build_command; + build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " + << source_file_path << " -o " << shared_lib_path; + std::string build_command_str = build_command.str(); + auto return_code = system(build_command_str.c_str()); + if (return_code) { + std::ostringstream msg; + msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name + << " with error code " << return_code << "." << std::endl; + throw std::runtime_error(msg.str()); + } + } + + // load library + libs.emplace_back(shared_lib_path); + + // Load function + void* fun = dlsym(libs.back().lib, kernel_name.c_str()); + if (!fun) { + std::ostringstream msg; + msg << "[Compile::eval_cpu] Failed to load compiled function " + << kernel_name << std::endl + << dlerror(); + throw std::runtime_error(msg.str()); + } + kernels.insert({kernel_name, fun}); + return fun; +} + +inline void build_kernel( + std::ostream& os, + const std::string& kernel_name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids, + bool contiguous, + int ndim) { + // All outputs should have the exact same shape and will be row contiguous + auto output_shape = outputs[0].shape(); + auto output_strides = outputs[0].strides(); + + // Constants are scalars that are captured by value and cannot change + auto is_constant = [&constant_ids](const array& x) { + return constant_ids.find(x.id()) != constant_ids.end(); + }; + + NodeNamer namer; + + // Start the kernel + os << "void " << kernel_name << "(void** args) {" << std::endl; + + // Add the input arguments + int cnt = 0; + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + // Skip constants from the input list + if (is_constant(x)) { + continue; + } + + auto tstr = get_type_string(x.dtype()); + os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ + << "];" << std::endl; + // Scalars and contiguous need no strides + if (!is_scalar(x) && !contiguous) { + os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++ + << "];" << std::endl; + } + } + + // Add the output arguments + for (auto& x : outputs) { + auto tstr = get_type_string(x.dtype()); + os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr + << "*)args[" << cnt++ << "];" << std::endl; + } + // Add output strides and shape to extract the indices. + if (!contiguous) { + os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl; + } else { + os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; + } + + if (contiguous) { + os << " for (size_t i = 0; i < size; ++i) {" << std::endl; + } else { + for (int d = 0; d < ndim; ++d) { + os << " for (int i" << d << " = 0; i" << d << " < shape[" << d + << "]; ++i" << d << ") {" << std::endl; + } + } + + // Read the inputs in tmps + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + if (is_constant(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; + print_constant(os, x); + os << ";" << std::endl; + } else if (is_scalar(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[0];" << std::endl; + } else if (contiguous) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[i];" << std::endl; + } else { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *" + << xname << ";" << std::endl; + } + } + + // Actually write the computation + for (auto& x : tape) { + os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) + << " = "; + if (is_static_cast(x.primitive())) { + os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" + << namer.get_name(x.inputs()[0]) << ");" << std::endl; + } else { + x.primitive().print(os); + os << "()("; + for (int i = 0; i < x.inputs().size() - 1; i++) { + os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; + } + os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; + } + } + + // Write the outputs from tmps + for (auto& x : outputs) { + if (contiguous) { + os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x) + << ";" << std::endl; + } else { + os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x) + << ";" << std::endl; + } + } + + // Close loops + if (contiguous) { + os << " }" << std::endl; + } else { + for (int d = ndim - 1; d >= 0; --d) { + // Update pointers + for (auto& x : inputs) { + if (is_constant(x) || is_scalar(x)) { + continue; + } + auto& xname = namer.get_name(x); + os << " " << xname << " += " << xname << "_strides[" << d << "];" + << std::endl; + if (d < ndim - 1) { + os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]" + << " * shape[" << d + 1 << "];" << std::endl; + } + } + os << " }" << std::endl; + } + } + + // Finish the kernel + os << "}" << std::endl; +} + +void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { - // Make the a real tape from the tracers - auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs); - - // Run the tape - while (!tape.empty()) { - auto a = std::move(tape.front()); - tape.pop(); - auto outputs = a.outputs(); - a.primitive().eval_cpu(a.inputs(), outputs); - a.detach(); + if (kernel_lib_.empty()) { + kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); } - // Copy results into outputs - for (int o = 0; o < real_outputs.size(); ++o) { - outputs[o].copy_shared_buffer(real_outputs[o]); + // Figure out which kernel we are using + auto& shape = outputs[0].shape(); + bool contiguous = true; + { + bool all_contig = true; + bool all_row_contig = true; + bool all_col_contig = true; + int non_scalar_inputs = 0; + for (auto& x : inputs) { + if (is_scalar(x)) { + continue; + } + non_scalar_inputs++; + bool shape_eq = x.shape() == shape; + all_contig &= (x.flags().contiguous && shape_eq); + all_row_contig &= (x.flags().row_contiguous && shape_eq); + all_col_contig &= (x.flags().col_contiguous && shape_eq); + } + if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) { + contiguous = false; + } else if (non_scalar_inputs == 1 && !all_contig) { + contiguous = false; + } } + + // Handle all broadcasting and collect function input arguments + std::vector args; + std::vector> strides; + for (int i = 0; i < inputs.size(); i++) { + // Skip constants. + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + args.push_back((void*)x.data()); + + if (contiguous || is_scalar(x)) { + continue; + } + + // Broadcast the input to the output shape. + std::vector xstrides; + int j = 0; + for (; j < shape.size() - x.ndim(); j++) { + if (shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (int i = 0; i < x.ndim(); i++, j++) { + if (x.shape(i) == 1) { + if (shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + strides.push_back(std::move(xstrides)); + args.push_back(strides.back().data()); + } + + // Get the kernel name from the lib + int ndim = shape.size(); + auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + if (!contiguous) { + kernel_name += std::to_string(shape.size()); + } + + // Get the function + auto fn_ptr = compile(kernel_name); + + // If it doesn't exist, compile it + if (fn_ptr == nullptr) { + std::ostringstream kernel; + kernel << get_kernel_preamble() << std::endl; + kernel << "extern \"C\" {" << std::endl; + build_kernel( + kernel, + kernel_name, + inputs_, + outputs_, + tape_, + constant_ids_, + contiguous, + ndim); + // Close extern "C" + kernel << "}" << std::endl; + + // Compile and get function pointer + fn_ptr = compile(kernel_name, kernel.str()); + } + + // Allocate space for the outputs possibly with input donation + if (contiguous) { + int o = 0; + std::vector strides; + size_t data_size; + array::Flags flags; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].copy_shared_buffer(in); + } + // Get representative input flags to properly set non-donated outputs + if (strides.empty() && in.size() == outputs[0].size()) { + strides = in.strides(); + flags = in.flags(); + data_size = in.data_size(); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data( + allocator::malloc_or_wait(data_size * outputs[o].itemsize()), + data_size, + strides, + flags); + } + } else { + int o = 0; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Row contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].copy_shared_buffer(in); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + } + } + + for (auto& x : outputs) { + args.push_back(x.data()); + } + if (!contiguous) { + args.push_back((void*)outputs[0].shape().data()); + } else { + args.push_back((void*)outputs[0].data_size()); + } + auto fun = (void (*)(void**))fn_ptr; + fun(args.data()); } } // namespace mlx::core diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h new file mode 100644 index 000000000..d01fe4fdc --- /dev/null +++ b/mlx/backend/common/compiled.h @@ -0,0 +1,56 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +inline bool is_static_cast(const Primitive& p) { + return ( + typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || + typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); +} + +std::string build_lib_name( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids); + +std::string get_type_string(Dtype d); + +template +void print_float_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + os << std::setprecision(std::numeric_limits::digits10 + 1) + << x.item() << std::setprecision(old_precision); +} + +template +void print_int_constant(std::ostream& os, const array& x) { + os << x.item(); +} + +template +void print_complex_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + T constant = x.item(); + + os << get_type_string(x.dtype()) << "(" + << std::setprecision(std::numeric_limits::digits10 + 1) + << constant.real() << ", " << constant.imag() << ")" + << std::setprecision(old_precision); +} + +void print_constant(std::ostream& os, const array& x); + +inline bool is_scalar(const array& x) { + return x.ndim() == 0; +} + +} // namespace mlx::core diff --git a/mlx/backend/common/compiled_preamble.h b/mlx/backend/common/compiled_preamble.h new file mode 100644 index 000000000..84b77d29d --- /dev/null +++ b/mlx/backend/common/compiled_preamble.h @@ -0,0 +1,11 @@ +// Copyright © 2023-24 Apple Inc. + +#pragma once + +// clang-format off +#include "mlx/types/half_types.h" +#include "mlx/types/complex.h" +#include "mlx/backend/common/ops.h" +// clang-format on + +const char* get_kernel_preamble(); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 6befc8eb9..c65028d95 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -43,7 +43,6 @@ DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT_MULTI(DivMod) DEFAULT(Ceil) -DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Convolution) DEFAULT(Copy) diff --git a/mlx/backend/common/erf.h b/mlx/backend/common/erf.h deleted file mode 100644 index a175a0c43..000000000 --- a/mlx/backend/common/erf.h +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright © 2023 Apple Inc. - -namespace mlx::core { - -/* Approximation to the inverse error function. - * Based on code from: - * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348 - */ -float erfinv(float a); - -} // namespace mlx::core diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh new file mode 100644 index 000000000..687f4cfc7 --- /dev/null +++ b/mlx/backend/common/make_compiled_preamble.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# +# This script generates a C++ function that provides the CPU +# code for use with kernel generation. +# +# Copyright © 2023-24 Apple Inc. + + +OUTPUT_FILE=$1 +GCC=$2 +SRCDIR=$3 +CLANG=$4 + +if [ $CLANG = "TRUE" ]; then + read -r -d '' INCLUDES <<- EOM + #include + #include + #include + #include +EOM + +fi + +CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null) + +cat << EOF > "$OUTPUT_FILE" +const char* get_kernel_preamble() { +return R"preamble( +$INCLUDES +$CONTENT +using namespace mlx::core::detail; +)preamble"; +} +EOF diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h new file mode 100644 index 000000000..8b2d7ab58 --- /dev/null +++ b/mlx/backend/common/ops.h @@ -0,0 +1,591 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once +#include +#include +#include + +namespace mlx::core::detail { + +typedef union { + int i; + float f; +} IntOrFloat; + +inline float fast_exp(float x) { + if (x == -std::numeric_limits::infinity()) { + return 0.0f; + } else if (x == std::numeric_limits::infinity() || std::isnan(x)) { + return x; + } + x *= 1.442695; // multiply with log_2(e) + float ipart, fpart; + IntOrFloat epart; + x = std::max(-80.f, std::min(x, 80.f)); + ipart = std::floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = x * fpart + 1.339887440266574e-3f; + x = x * fpart + 9.618437357674640e-3f; + x = x * fpart + 5.550332471162809e-2f; + x = x * fpart + 2.402264791363012e-1f; + x = x * fpart + 6.931472028550421e-1f; + x = x * fpart + 1.000000000000000f; + + // generate 2**ipart in the floating point representation using integer + // bitshifting + epart.i = (int(ipart) + 127) << 23; + + return epart.f * x; +} + +inline float fast_erf(float a) { + float r, s, t, u; + t = std::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = std::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = std::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = std::fma(r, s, u); + r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = std::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - std::exp(r); + r = std::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = std::fma(r, a, a); + } + return r; +} + +inline float fast_erfinv(float a) { + auto t = std::fma(a, 0.0f - a, 1.0f); + t = std::log(t); + float p; + if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} + +struct Abs { + template + T operator()(T x) { + return std::abs(x); + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return std::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return std::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return std::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return std::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return std::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return std::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return std::ceil(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return std::cos(x); + }; +}; + +struct Cosh { + template + T operator()(T x) { + return std::cosh(x); + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(fast_erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(fast_erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return fast_exp(x); + }; + + complex64_t operator()(complex64_t x) { + return std::exp(x); + } +}; + +struct Floor { + template + T operator()(T x) { + return std::floor(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return std::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return std::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return std::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return std::rint(x); + } + + complex64_t operator()(complex64_t x) { + return {std::rint(x.real()), std::rint(x.imag())}; + } +}; + +struct Sigmoid { + template + T operator()(T x) { + auto one = static_cast(1.0); + return one / (one + fast_exp(-x)); + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + } + uint8_t operator()(uint8_t x) { + return x != 0; + } + uint16_t operator()(uint16_t x) { + return x != 0; + } + uint32_t operator()(uint32_t x) { + return x != 0; + } + uint64_t operator()(uint64_t x) { + return x != 0; + } +}; + +struct Sin { + template + T operator()(T x) { + return std::sin(x); + }; +}; + +struct Sinh { + template + T operator()(T x) { + return std::sinh(x); + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return std::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return static_cast(1.0) / std::sqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return std::tan(x); + }; +}; + +struct Tanh { + template + T operator()(T x) { + return std::tanh(x); + }; +}; + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + std::enable_if_t & !std::is_signed_v, T> operator()( + T numerator, + T denominator) { + return numerator % denominator; + } + + template + std::enable_if_t & std::is_signed_v, T> operator()( + T numerator, + T denominator) { + auto r = numerator % denominator; + if (r != 0 && (r < 0 != denominator < 0)) + r += denominator; + return r; + } + + template + std::enable_if_t, T> operator()( + T numerator, + T denominator) { + auto r = std::fmod(numerator, denominator); + if (r != 0 && (r < 0 != denominator < 0)) { + r += denominator; + } + return r; + } + + complex64_t operator()(complex64_t numerator, complex64_t denominator) { + return numerator % denominator; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (std::isnan(x) && std::isnan(y)); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct Maximum { + template + std::enable_if_t, T> operator()(T x, T y) { + return (x > y) ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return (x > y) ? x : y; + } +}; + +struct Minimum { + template + std::enable_if_t, T> operator()(T x, T y) { + return x < y ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return x < y ? x : y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + constexpr float inf = std::numeric_limits::infinity(); + auto maxval = Maximum()(x, y); + auto minval = Minimum()(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : static_cast( + maxval + std::log1p(fast_exp(minval - maxval))); + }; +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } +}; + +struct Power { + template + std::enable_if_t, T> operator()(T base, T exp) { + return std::pow(base, exp); + } + + template + std::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; + +} // namespace mlx::core::detail diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 37c61761a..a1e99d7c7 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -10,7 +10,7 @@ #include "mlx/backend/common/arange.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/copy.h" -#include "mlx/backend/common/erf.h" +#include "mlx/backend/common/ops.h" #include "mlx/backend/common/threefry.h" #include "mlx/backend/common/unary.h" #include "mlx/backend/common/utils.h" @@ -26,7 +26,7 @@ void Abs::eval(const std::vector& inputs, array& out) { // No-op for unsigned types out.copy_shared_buffer(in); } else { - unary(in, out, AbsOp()); + unary(in, out, detail::Abs()); } } @@ -38,7 +38,7 @@ void ArcCos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::acos(x); }); + unary_fp(in, out, detail::ArcCos()); } else { throw std::invalid_argument( "[arccos] Cannot compute inverse cosine of elements in array" @@ -50,7 +50,7 @@ void ArcCosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::acosh(x); }); + unary_fp(in, out, detail::ArcCosh()); } else { throw std::invalid_argument( "[arccosh] Cannot compute inverse hyperbolic cosine of elements in" @@ -62,7 +62,7 @@ void ArcSin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::asin(x); }); + unary_fp(in, out, detail::ArcSin()); } else { throw std::invalid_argument( "[arcsin] Cannot compute inverse sine of elements in array" @@ -74,7 +74,7 @@ void ArcSinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::asinh(x); }); + unary_fp(in, out, detail::ArcSinh()); } else { throw std::invalid_argument( "[arcsinh] Cannot compute inverse hyperbolic sine of elements in" @@ -86,7 +86,7 @@ void ArcTan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::atan(x); }); + unary_fp(in, out, detail::ArcTan()); } else { throw std::invalid_argument( "[arctan] Cannot compute inverse tangent of elements in array" @@ -98,7 +98,7 @@ void ArcTanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::atanh(x); }); + unary_fp(in, out, detail::ArcTanh()); } else { throw std::invalid_argument( "[arctanh] Cannot compute inverse hyperbolic tangent of elements in" @@ -172,7 +172,7 @@ void Ceil::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, [](auto x) { return std::ceil(x); }); + unary_fp(in, out, detail::Ceil()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -212,7 +212,7 @@ void Cos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::cos(x); }); + unary_fp(in, out, detail::Cos()); } else { throw std::invalid_argument( "[cos] Cannot compute cosine of elements in array" @@ -224,7 +224,7 @@ void Cosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::cosh(x); }); + unary_fp(in, out, detail::Cosh()); } else { throw std::invalid_argument( "[cosh] Cannot compute hyperbolic cosine of elements in array" @@ -256,17 +256,13 @@ void Erf::eval(const std::vector& inputs, array& out) { const auto& in = inputs[0]; switch (out.dtype()) { case float32: - unary_op(in, out, [](auto x) { return std::erf(x); }); + unary_op(in, out, detail::Erf()); break; case float16: - unary_op(in, out, [](auto x) { - return static_cast(std::erf(static_cast(x))); - }); + unary_op(in, out, detail::Erf()); break; case bfloat16: - unary_op(in, out, [](auto x) { - return static_cast(std::erf(static_cast(x))); - }); + unary_op(in, out, detail::Erf()); break; default: throw std::invalid_argument( @@ -280,17 +276,13 @@ void ErfInv::eval(const std::vector& inputs, array& out) { const auto& in = inputs[0]; switch (out.dtype()) { case float32: - unary_op(in, out, [](auto x) { return erfinv(x); }); + unary_op(in, out, detail::ErfInv()); break; case float16: - unary_op(in, out, [](auto x) { - return static_cast(erfinv(static_cast(x))); - }); + unary_op(in, out, detail::ErfInv()); break; case bfloat16: - unary_op(in, out, [](auto x) { - return static_cast(erfinv(static_cast(x))); - }); + unary_op(in, out, detail::ErfInv()); break; default: throw std::invalid_argument( @@ -302,9 +294,8 @@ void ErfInv::eval(const std::vector& inputs, array& out) { void Exp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::exp(x); }); + unary_fp(in, out, detail::Exp()); } else { throw std::invalid_argument( "[exp] Cannot exponentiate elements in array" @@ -316,7 +307,7 @@ void Floor::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, [](auto x) { return std::floor(x); }); + unary_fp(in, out, detail::Floor()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -344,13 +335,13 @@ void Log::eval(const std::vector& inputs, array& out) { if (is_floating_point(out.dtype())) { switch (base_) { case Base::e: - unary_fp(in, out, [](auto x) { return std::log(x); }); + unary_fp(in, out, detail::Log()); break; case Base::two: - unary_fp(in, out, [](auto x) { return std::log2(x); }); + unary_fp(in, out, detail::Log2()); break; case Base::ten: - unary_fp(in, out, [](auto x) { return std::log10(x); }); + unary_fp(in, out, detail::Log10()); break; } } else { @@ -364,7 +355,7 @@ void Log1p::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::log1p(x); }); + unary_fp(in, out, detail::Log1p()); } else { throw std::invalid_argument( "[log1p] Cannot compute log of elements in array with" @@ -375,27 +366,27 @@ void Log1p::eval(const std::vector& inputs, array& out) { void LogicalNot::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return !x; }); + unary(in, out, detail::LogicalNot()); } void LogicalAnd::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalAnd requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, [](auto x, auto y) { return x && y; }); + binary(in1, in2, out, detail::LogicalAnd()); } void LogicalOr::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalOr requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, [](auto x, auto y) { return x || y; }); + binary(in1, in2, out, detail::LogicalOr()); } void Negative::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return -x; }); + unary(in, out, detail::Negative()); } void Pad::eval(const std::vector& inputs, array& out) { @@ -498,7 +489,7 @@ void Round::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, RoundOp()); + unary_fp(in, out, detail::Round()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -509,11 +500,7 @@ void Sigmoid::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - auto sigmoid_op = [](auto x) { - auto one = static_cast(1.0); - return one / (one + std::exp(-x)); - }; - unary_fp(in, out, sigmoid_op); + unary_fp(in, out, detail::Sigmoid()); } else { throw std::invalid_argument( "[sigmoid] Cannot sigmoid of elements in array with" @@ -527,7 +514,7 @@ void Sign::eval(const std::vector& inputs, array& out) { if (in.dtype() == bool_) { out.copy_shared_buffer(in); } else { - unary(in, out, SignOp()); + unary(in, out, detail::Sign()); } } @@ -535,7 +522,7 @@ void Sin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::sin(x); }); + unary_fp(in, out, detail::Sin()); } else { throw std::invalid_argument( "[sin] Cannot compute sine of elements in array" @@ -547,7 +534,7 @@ void Sinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::sinh(x); }); + unary_fp(in, out, detail::Sinh()); } else { throw std::invalid_argument( "[sinh] Cannot compute hyperbolic sine of elements in array" @@ -656,18 +643,16 @@ void Split::eval( void Square::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return x * x; }); + unary(in, out, detail::Square()); } void Sqrt::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (recip_) { - unary_fp(in, out, [](auto x) { - return static_cast(1.0) / sqrt(x); - }); + unary_fp(in, out, detail::Rsqrt()); } else { - unary_fp(in, out, [](auto x) { return sqrt(x); }); + unary_fp(in, out, detail::Sqrt()); } } @@ -680,7 +665,7 @@ void Tan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::tan(x); }); + unary_fp(in, out, detail::Tan()); } else { throw std::invalid_argument( "[tan] Cannot compute tangent of elements in array" @@ -692,7 +677,7 @@ void Tanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::tanh(x); }); + unary_fp(in, out, detail::Tanh()); } else { throw std::invalid_argument( "[tanh] Cannot compute hyperbolic tangent of elements in array" diff --git a/mlx/backend/common/rope.cpp b/mlx/backend/common/rope.cpp new file mode 100644 index 000000000..15b5de7e5 --- /dev/null +++ b/mlx/backend/common/rope.cpp @@ -0,0 +1,13 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/fast_primitives.h" + +namespace mlx::core::fast { + +void RoPE::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("NYI"); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp index 90874c72d..87ce748c8 100644 --- a/mlx/backend/common/softmax.cpp +++ b/mlx/backend/common/softmax.cpp @@ -53,7 +53,12 @@ void Softmax::eval(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous auto check_input = [](array x) { - if (x.strides().back() == 1) { + bool no_copy = x.strides()[x.ndim() - 1] == 1; + if (x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index 7fdcbeb77..28c4f0f4a 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -11,59 +11,6 @@ namespace mlx::core { namespace { -struct AbsOp { - template - T operator()(T x) { - return std::abs(x); - } - uint8_t operator()(uint8_t x) { - return x; - } - uint16_t operator()(uint16_t x) { - return x; - } - uint32_t operator()(uint32_t x) { - return x; - } - uint64_t operator()(uint64_t x) { - return x; - } - bool operator()(bool x) { - return x; - } -}; - -struct SignOp { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - } - - uint8_t operator()(uint8_t x) { - return x != 0; - } - uint16_t operator()(uint16_t x) { - return x != 0; - } - uint32_t operator()(uint32_t x) { - return x != 0; - } - uint64_t operator()(uint64_t x) { - return x != 0; - } -}; - -struct RoundOp { - template - T operator()(T x) { - return std::rint(x); - } - - complex64_t operator()(complex64_t x) { - return {std::rint(x.real()), std::rint(x.imag())}; - } -}; - void set_unary_output_data(const array& in, array& out) { if (in.is_donatable() && in.itemsize() == out.itemsize()) { out.copy_shared_buffer(in); diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index fd1a47f01..063c283fe 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -1,3 +1,23 @@ +add_custom_command( + OUTPUT compiled_preamble.cpp + COMMAND /bin/bash + ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_C_COMPILER} + ${CMAKE_SOURCE_DIR} + DEPENDS make_compiled_preamble.sh + kernels/compiled_preamble.h + kernels/unary.h + kernels/binary.h +) + +add_custom_target( + compiled_preamble + DEPENDS compiled_preamble.cpp +) + +add_dependencies(mlx compiled_preamble) + target_sources( mlx PRIVATE @@ -12,10 +32,12 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) if (NOT MLX_METAL_PATH) diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 4ce9b0c85..3b1ee116a 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -1,44 +1,378 @@ // Copyright © 2023-2024 Apple Inc. +#include + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/metal/compiled_preamble.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core { +inline void build_kernel( + std::ostream& os, + const std::string& kernel_name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::unordered_set& constant_ids, + bool contiguous, + int ndim, + bool dynamic_dims) { + // All outputs should have the exact same shape and will be row contiguous + auto output_shape = outputs[0].shape(); + auto output_strides = outputs[0].strides(); + + // Constants are scalars that are captured by value and cannot change + auto is_constant = [&constant_ids](const array& x) { + return constant_ids.find(x.id()) != constant_ids.end(); + }; + + NodeNamer namer; + bool add_indices = false; + int cnt = 0; + + // Start the kernel + os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl + << "[[kernel]] void " << kernel_name << "(" << std::endl; + + // Add the input arguments + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + // Skip constants from the input list + if (is_constant(x)) { + continue; + } + + // Scalars and contiguous need no strides + if (is_scalar(x) || contiguous) { + os << " device const " << get_type_string(x.dtype()) << "* " << xname + << " [[buffer(" << cnt++ << ")]]," << std::endl; + } else { + add_indices = true; + os << " device const " << get_type_string(x.dtype()) << "* " << xname + << " [[buffer(" << cnt++ << ")]]," << std::endl + << " constant const size_t* " << xname << "_strides [[buffer(" + << cnt++ << ")]]," << std::endl; + } + } + + // Add the output arguments + for (auto& x : outputs) { + os << " device " << get_type_string(x.dtype()) << "* " + << namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl; + } + // Add output strides and shape to extract the indices. + if (!contiguous) { + os << " constant const size_t* output_strides [[buffer(" << cnt++ + << ")]]," << std::endl + << " constant const int* output_shape [[buffer(" << cnt++ << ")]]," + << std::endl; + } + if (dynamic_dims) { + os << " constant const int& ndim [[buffer(" << cnt++ << ")]]," + << std::endl; + } + + // The thread index in the whole grid + os << " uint3 pos [[thread_position_in_grid]]," << std::endl + << " uint3 grid [[threads_per_grid]]) {" << std::endl + << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);" + << std::endl; + + // Extract the indices per axis to individual uints if we have arrays that + // are broadcasted or transposed + if (add_indices) { + if (!dynamic_dims) { + if (ndim == 1) { + os << " uint index_0 = pos.x;" << std::endl; + } else if (ndim == 2) { + os << " uint index_0 = pos.y;" << std::endl + << " uint index_1 = pos.x;" << std::endl; + } else if (ndim == 3) { + os << " uint index_0 = pos.z;" << std::endl + << " uint index_1 = pos.y;" << std::endl + << " uint index_2 = pos.x;" << std::endl; + } else { + for (int i = 0; i < ndim - 2; i++) { + os << " uint index_" << i << " = (index / uint(output_strides[" << i + << "])) % output_shape[" << i << "];" << std::endl; + } + os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl + << " uint index_" << ndim - 1 << " = pos.x;" << std::endl; + } + } + } + + // Read the inputs in tmps + for (auto& x : inputs) { + auto& xname = namer.get_name(x); + + if (is_constant(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; + print_constant(os, x); + os << ";" << std::endl; + } else if (is_scalar(x)) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[0];" << std::endl; + } else if (contiguous) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[index];" << std::endl; + } else if (!dynamic_dims) { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "["; + os << "index_0 * " << xname << "_strides[0]"; + for (int i = 1; i < ndim; i++) { + os << " + index_" << i << " * " << xname << "_strides[" << i << "]"; + } + os << "];" << std::endl; + } else { + os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " + << xname << "[elem_to_loc(index, output_shape, " << xname + << "_strides, ndim)];" << std::endl; + } + } + + // Actually write the computation + for (auto& x : tape) { + os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) + << " = "; + if (is_static_cast(x.primitive())) { + os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" + << namer.get_name(x.inputs()[0]) << ");" << std::endl; + } else { + x.primitive().print(os); + os << "()("; + for (int i = 0; i < x.inputs().size() - 1; i++) { + os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; + } + os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; + } + } + + // Write the outputs from tmps + for (auto& x : outputs) { + os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x) + << ";" << std::endl; + } + + // Finish the kernel + os << "}" << std::endl; + + if (cnt > 31) { + std::ostringstream msg; + msg << "[compile] Too many inputs/outputs fused in the Metal Compiled " + << "primitive which exhausted the available argument buffers for " + << "the kernel. Please file an issue with the function that results " + << "in this error. The name of the kernel is '" << kernel_name << "'"; + throw std::runtime_error(msg.str()); + } +} + void Compiled::eval_gpu( const std::vector& inputs, std::vector& outputs) { - // Just a fall-back to the original tape for now - std::unordered_map trace_to_real; - for (int i = 0; i < inputs.size(); ++i) { - trace_to_real.insert({inputs_[i].id(), inputs[i]}); - } - for (int i = 0; i < outputs.size(); ++i) { - trace_to_real.insert({outputs_[i].id(), outputs[i]}); + // Make the name for the kernel library + if (kernel_lib_.empty()) { + kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); } - for (auto& a : tape_) { - std::vector p_inputs; - for (auto& in : a.inputs()) { - p_inputs.push_back(trace_to_real.at(in.id())); - } - // If a is an output get it from the map, otherwise create it - // NB this is safe as long as no multi-output sub primitves are allowed - // in Compiled - std::vector p_outputs; - if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) { - p_outputs.push_back(it->second); - } else { - p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {})); - trace_to_real.insert({a.id(), p_outputs[0]}); - } - a.primitive().eval_gpu(p_inputs, p_outputs); - } + // Get the kernel if someone else built it already auto& s = stream(); auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler( - [trace_to_real](MTL::CommandBuffer*) mutable {}); + auto lib = d.get_library(kernel_lib_); + + // If not we have to build it ourselves + if (lib == nullptr) { + std::ostringstream kernel; + kernel << metal::get_kernel_preamble() << std::endl; + build_kernel( + kernel, + kernel_lib_ + "_contiguous", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ true, + /* ndim = */ 0, + /* dynamic_dims = */ false); + for (int i = 1; i < 8; i++) { + build_kernel( + kernel, + kernel_lib_ + "_strided_" + std::to_string(i), + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ i, + /* dynamic_dims = */ false); + } + build_kernel( + kernel, + kernel_lib_ + "_strided_dynamic", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ 0, + /* dynamic_dims = */ true); + + lib = d.get_library(kernel_lib_, kernel.str()); + } + + // Figure out which kernel we are using + auto& output_shape = outputs[0].shape(); + bool contiguous = true; + for (auto& x : inputs) { + if ((!x.flags().row_contiguous || x.shape() != output_shape) && + !is_scalar(x)) { + contiguous = false; + break; + } + } + + // Collapse contiguous dims to route to a faster kernel if possible. Also + // handle all broadcasting. + std::vector> initial_strides; + initial_strides.push_back(outputs[0].strides()); + std::vector shape; + std::vector> strides; + if (!contiguous) { + for (int i = 0; i < inputs.size(); i++) { + // Skip constants. + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + + // Skip scalar inputs. + if (is_scalar(x)) { + continue; + } + + // Broadcast the inputs to the output shape. + std::vector xstrides; + int j = 0; + for (; j < output_shape.size() - x.ndim(); j++) { + if (output_shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (int i = 0; i < x.ndim(); i++, j++) { + if (x.shape(i) == 1) { + if (output_shape[j] == 1) { + xstrides.push_back(outputs[0].strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + initial_strides.push_back(std::move(xstrides)); + } + std::tie(shape, strides) = + collapse_contiguous_dims(output_shape, initial_strides); + } + + // Get the kernel from the lib + int ndim = shape.size(); + bool dynamic = ndim >= 8; + auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + if (!contiguous) { + if (dynamic) { + kernel_name += "dynamic"; + } else { + kernel_name += std::to_string(shape.size()); + } + } + auto kernel = d.get_kernel(kernel_name, lib); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + // Put the inputs in + int cnt = 0; + int stride_idx = 1; // idx 0 is the output strides + for (int i = 0; i < inputs.size(); i++) { + if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + continue; + } + auto& x = inputs[i]; + set_array_buffer(compute_encoder, x, cnt++); + if (!contiguous && !is_scalar(x)) { + compute_encoder->setBytes( + strides[stride_idx].data(), + strides[stride_idx].size() * sizeof(size_t), + cnt++); + stride_idx++; + } + } + + // Allocate space for the outputs possibly with input donation + { + int o = 0; + for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + auto& in = inputs[i]; + // Conditions for donation + // - Row contiguous + // - Donatable + // - Correct size + // - Not a constant + if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && + in.is_donatable() && + constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + outputs[o++].move_shared_buffer(in); + } + } + for (; o < outputs.size(); ++o) { + outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + } + } + + // Put the outputs in + for (auto& x : outputs) { + set_array_buffer(compute_encoder, x, cnt++); + } + + // Put the output shape and strides in + if (!contiguous) { + compute_encoder->setBytes( + strides[0].data(), strides[0].size() * sizeof(size_t), cnt++); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++); + } + + // Put the number of dims in if it is dynamic + if (dynamic) { + compute_encoder->setBytes(&ndim, sizeof(int), cnt++); + } + + // Launch the kernel + if (contiguous) { + size_t nthreads = outputs[0].size(); + MTL::Size grid_dims(nthreads, 1, 1); + MTL::Size group_dims( + std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t rest = outputs[0].size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/compiled_preamble.h b/mlx/backend/metal/compiled_preamble.h new file mode 100644 index 000000000..9122d3d54 --- /dev/null +++ b/mlx/backend/metal/compiled_preamble.h @@ -0,0 +1,9 @@ +// Copyright © 2023-24 Apple Inc. + +#pragma once + +namespace mlx::core::metal { + +const char* get_kernel_preamble(); + +} diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 9fbbffb33..4ade8da17 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -182,7 +182,6 @@ void implicit_gemm_conv_2D_gpu( int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; int implicit_N = conv_params.O; - int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C; size_t grid_dim_x = (implicit_N + bn - 1) / bn; size_t grid_dim_y = (implicit_M + bm - 1) / bm; diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7c61e68ae..1fb2bd46f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -215,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) { return eit->second; } -MTL::ArgumentEncoder* Device::argument_encoder( - const std::vector& arg_descs) const { - // NB array here is already autoreleased but the returned argument - // encoder is owned by the caller and must be released/autoreleased - NS::Array* arg_desc_arr = NS::Array::array( - reinterpret_cast(arg_descs.data()), arg_descs.size()); - return device_->newArgumentEncoder(arg_desc_arr); -} - void Device::register_library( const std::string& lib_name, const std::string& lib_path) { @@ -414,6 +405,11 @@ MTL::ComputePipelineState* Device::get_kernel_( return kernel; } +MTL::Library* Device::get_library(const std::string& name) { + auto it = library_map_.find(name); + return (it != library_map_.end()) ? it->second : nullptr; +} + MTL::Library* Device::get_library( const std::string& name, const std::string& source, diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 6acfe9332..8312084ce 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -62,6 +62,8 @@ class Device { const std::function& lib_path_func = get_colocated_mtllib_path); + MTL::Library* get_library(const std::string& name); + MTL::Library* get_library( const std::string& name, const std::string& source_string, diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 28edb9126..4eeb8858e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -51,6 +51,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); size_t slice_size = 1; for (auto s : slice_sizes_) { @@ -63,91 +64,50 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto group_dims = get_block_dims(dim0, dim1, 1); MTL::Size grid_dims = MTL::Size(dim0, dim1, 1); - compute_encoder->setComputePipelineState(kernel); + // Collect all idx shapes and strides into one place + std::vector idx_shapes; + std::vector idx_strides; - // Make the argument buffer to store the indices for the - // `Indices` struct in kernels/indexing.metal - std::vector arg_descs(4); - arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[0]->setIndex(0); - arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[0]->setArrayLength(nidx); - - // Shapes - arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[1]->setIndex(nidx + 1); - - // Strides - arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[2]->setIndex(nidx + 2); - - // Indices ndim - arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); - arg_descs[3]->setIndex(nidx + 3); - - // Get the argument encoder - auto arg_enc = d.argument_encoder(arg_descs); - - // Allocate and fill buffers for shapes and strides - auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); - auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); for (int i = 0; i < nidx; ++i) { - std::copy( + idx_shapes.insert( + idx_shapes.end(), inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end(), - static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); - std::copy( + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end(), - static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + inputs[i + 1].strides().end()); } - // Allocate the argument buffer - auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); - - // Register data with the encoder - arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); - for (int i = 0; i < nidx; ++i) { - set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); - } - if (idx_ndim > 0) { - arg_enc->setBuffer( - static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); - compute_encoder->useResource( - static_cast(idx_shapes_buf.ptr()), - MTL::ResourceUsageRead); - arg_enc->setBuffer( - static_cast(idx_strides_buf.ptr()), 0, nidx + 2); - compute_encoder->useResource( - static_cast(idx_strides_buf.ptr()), - MTL::ResourceUsageRead); - } - *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; - // Set all the buffers set_array_buffer(compute_encoder, src, 0); - compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 1); - set_array_buffer(compute_encoder, out, 2); + set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3); - compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4); - compute_encoder->setBytes(&ndim, sizeof(size_t), 5); - compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6); - compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7); + // Set source info + compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2); + compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3); + compute_encoder->setBytes(&ndim, sizeof(size_t), 4); + compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5); + compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6); + // Set index info + // + // We don't need to check for empty idx_shapes because gather has a + // idx_ndim == 0 specialization + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 7); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 8); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 9); + + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid compute_encoder->dispatchThreads(grid_dims, group_dims); - - // Cleanup temporaries - arg_enc->release(); - d.get_command_buffer(s.index)->addCompletedHandler( - [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { - allocator::free(arg_buf); - allocator::free(idx_shapes_buf); - allocator::free(idx_strides_buf); - }); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -207,87 +167,36 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { auto& upd = inputs.back(); size_t nthreads = upd.size(); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder->setComputePipelineState(kernel); - // Make the argument buffer to store the indices for the - // `Indices` struct in kernels/indexing.metal - std::vector arg_descs(4); - arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[0]->setIndex(0); - arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[0]->setArrayLength(nidx); - - // Shapes - arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[1]->setIndex(nidx + 1); - - // Strides - arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[2]->setIndex(nidx + 2); - - // Indices ndim - arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); - arg_descs[3]->setIndex(nidx + 3); - - // Get the argument encoder - auto arg_enc = d.argument_encoder(arg_descs); - - // Allocate and fill buffers for shapes and strides + // Collect all idx shapes and strides into one place int idx_ndim = nidx ? inputs[1].ndim() : 0; - auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); - auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); + std::vector idx_shapes; + std::vector idx_strides; + for (int i = 0; i < nidx; ++i) { - std::copy( + idx_shapes.insert( + idx_shapes.end(), inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end(), - static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); - std::copy( + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end(), - static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + inputs[i + 1].strides().end()); } - // Allocate the argument buffer - auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); + // Set all the buffers + set_array_buffer(compute_encoder, upd, 1); + set_array_buffer(compute_encoder, out, 2); - // Register data with the encoder - arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); - for (int i = 0; i < nidx; ++i) { - set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); - } - if (idx_ndim > 0) { - arg_enc->setBuffer( - static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); - compute_encoder->useResource( - static_cast(idx_shapes_buf.ptr()), - MTL::ResourceUsageRead); - arg_enc->setBuffer( - static_cast(idx_strides_buf.ptr()), 0, nidx + 2); - compute_encoder->useResource( - static_cast(idx_strides_buf.ptr()), - MTL::ResourceUsageRead); - } - *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; - - compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 0); + // Set update info size_t upd_ndim = upd.ndim(); size_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } - set_array_buffer(compute_encoder, upd, 1); - set_array_buffer(compute_encoder, out, 2); if (upd_ndim == 0) { // Need placeholders so Metal doesn't compalain int shape_ = 0; @@ -302,6 +211,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + // Set output info size_t out_ndim = out.ndim(); if (out_ndim == 0) { // Need placeholders so Metal doesn't compalain @@ -317,16 +227,28 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); - compute_encoder->dispatchThreads(grid_dims, group_dims); + // Set index info + if (idx_ndim == 0) { + // Add a 0 in idx_shapes and strides to avoid the missing buffer binding + // error in the metal API. + idx_shapes.push_back(0); + idx_strides.push_back(0); + } + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); - // Cleanup temporaries - arg_enc->release(); - d.get_command_buffer(s.index)->addCompletedHandler( - [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { - allocator::free(arg_buf); - allocator::free(idx_shapes_buf); - allocator::free(idx_strides_buf); - }); + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid + MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); + MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 2d271abb4..afd2fbc8a 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -6,6 +6,7 @@ set( ${CMAKE_CURRENT_SOURCE_DIR}/complex.h ${CMAKE_CURRENT_SOURCE_DIR}/defines.h ${CMAKE_CURRENT_SOURCE_DIR}/erf.h + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h ${CMAKE_CURRENT_SOURCE_DIR}/utils.h ) @@ -22,11 +23,13 @@ set( "quantized" "random" "reduce" + "rope" "scan" "softmax" "sort" "unary" - "indexing" + "gather" + "scatter" ) function(build_kernel_base TARGET SRCFILE DEPS) diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h new file mode 100644 index 000000000..006f2ff0e --- /dev/null +++ b/mlx/backend/metal/kernels/binary.h @@ -0,0 +1,231 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && + metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + if (metal::isnan(x) || metal::isnan(y)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : (maxval + log1p(metal::exp(minval - maxval))); + }; +}; + +struct Maximum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::max(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x > y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x > y ? x : y; + } +}; + +struct Minimum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::min(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x < y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x < y ? x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + auto x_theta = metal::atan(x.imag / x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 1b84c70a5..4d449ab69 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -1,176 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/bf16.h" - -struct Add { - template T operator()(T x, T y) { return x + y; } -}; - -struct Divide { - template T operator()(T x, T y) { return x / y; } -}; - -struct Remainder { - template T operator()(T x, T y) { return x % y; } - template <> float operator()(float x, float y) { return fmod(x, y); } - template <> half operator()(half x, half y) { return fmod(x, y); } - template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); } -}; - -struct Equal { - template bool operator()(T x, T y) { return x == y; } -}; - -struct NaNEqual { - template bool operator()(T x, T y) { - return x == y || (metal::isnan(x) && metal::isnan(y)); - } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x == y || - (metal::isnan(x.real) && metal::isnan(y.real) - && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); - } -}; - -struct Greater { - template bool operator()(T x, T y) { return x > y; } -}; - -struct GreaterEqual { - template bool operator()(T x, T y) { return x >= y; } -}; - -struct Less { - template bool operator()(T x, T y) { return x < y; } -}; - -struct LessEqual { - template bool operator()(T x, T y) { return x <= y; } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - if (metal::isnan(x) || metal::isnan(y)) { - return metal::numeric_limits::quiet_NaN(); - } - constexpr T inf = metal::numeric_limits::infinity(); - T maxval = metal::max(x, y); - T minval = metal::min(x, y); - return (minval == -inf || maxval == inf) ? maxval : - (maxval + log1p(metal::exp(minval - maxval))); - }; -}; - -struct Maximum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::max(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x > y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x > y ? x : y; - } -}; - -struct Minimum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::min(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x < y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x < y ? x : y; - } -}; - -struct Multiply { - template T operator()(T x, T y) { return x * y; } -}; - -struct NotEqual { - template bool operator()(T x, T y) { return x != y; } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x.real != y.real || x.imag != y.imag; - } -}; - -struct Power { - - template - metal::enable_if_t, T> operator()(T base, T exp) { - return metal::pow(base, exp); - } - - template - metal::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - auto x_theta = metal::atan(x.imag / x.real); - auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); - auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); - auto phase = y.imag * x_ln_r + y.real * x_theta; - return {mag * metal::cos(phase), mag * metal::sin(phase)}; - } -}; - - -struct Subtract { - template T operator()(T x, T y) { return x - y; } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { return x && y; }; -}; - -struct LogicalOr { - template - T operator()(T x, T y) { return x || y; }; -}; +#include "mlx/backend/metal/kernels/binary.h" template [[kernel]] void binary_op_s2s( diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 3e4dbc8f2..245ced024 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -14,10 +14,29 @@ struct FloorDivide { }; struct Remainder { - template T operator()(T x, T y) { return x % y; } - template <> float operator()(float x, float y) { return fmod(x, y); } - template <> half operator()(half x, half y) { return fmod(x, y); } - template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); } + template + metal::enable_if_t & !metal::is_signed_v, T> operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } }; template diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h new file mode 100644 index 000000000..d5bf33696 --- /dev/null +++ b/mlx/backend/metal/kernels/compiled_preamble.h @@ -0,0 +1,6 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/kernels/binary.h" +#include "mlx/backend/metal/kernels/unary.h" + +typedef half float16_t; diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index ac966a293..9cb27c5a3 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) { constexpr complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real - (b.real * static_cast(a.real / b.real)); auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); + if (real != 0 && (real < 0 != b.real < 0)) { + real += b.real; + } + if (imag != 0 && (imag < 0 != b.imag < 0)) { + imag += b.imag; + } return {real, imag}; } diff --git a/mlx/backend/metal/kernels/gather.metal b/mlx/backend/metal/kernels/gather.metal new file mode 100644 index 000000000..793b2af62 --- /dev/null +++ b/mlx/backend/metal/kernels/gather.metal @@ -0,0 +1,187 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/indexing.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Gather kernel +///////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void gather_impl( + const device T *src [[buffer(0)]], + device T *out [[buffer(1)]], + const constant int *src_shape [[buffer(2)]], + const constant size_t *src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int *slice_sizes [[buffer(5)]], + const constant int *axes [[buffer(6)]], + const thread Indices& indices, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + + auto ind_idx = index.x; + auto ind_offset = index.y; + + size_t src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + size_t idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = ind_idx * indices.strides[indices.ndim * i]; + } else { + idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + } + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += idx_val * src_strides[ax]; + } + + auto src_offset = elem_to_loc( + ind_offset, slice_sizes, src_strides, src_ndim); + + size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; + out[out_idx] = src[src_offset + src_idx]; + +} + +#define make_gather_impl(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void gather( \ + const device T *src [[buffer(0)]], \ + device T *out [[buffer(1)]], \ + const constant int *src_shape [[buffer(2)]], \ + const constant size_t *src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int *slice_sizes [[buffer(5)]], \ + const constant int *axes [[buffer(6)]], \ + const constant int *idx_shapes [[buffer(7)]], \ + const constant size_t *idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(IdxT) \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]) { \ + \ + Indices idxs{ \ + {{IDX_ARR()}}, \ + idx_shapes, \ + idx_strides, \ + idx_ndim}; \ + \ + return gather_impl( \ + src, \ + out, \ + src_shape, \ + src_strides, \ + src_ndim, \ + slice_sizes, \ + axes, \ + idxs, \ + index, \ + grid_dim); \ +} + +#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) + +make_gather(0) +make_gather(1) +make_gather(2) +make_gather(3) +make_gather(4) +make_gather(5) +make_gather(6) +make_gather(7) +make_gather(8) +make_gather(9) +make_gather(10) + +///////////////////////////////////////////////////////////////////// +// Gather instantiations +///////////////////////////////////////////////////////////////////// + +#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ +template [[host_name("gather" name "_" #nidx "" #nd_name)]] \ +[[kernel]] void gather( \ + const device src_t *src [[buffer(0)]], \ + device src_t *out [[buffer(1)]], \ + const constant int *src_shape [[buffer(2)]], \ + const constant size_t *src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int *slice_sizes [[buffer(5)]], \ + const constant int *axes [[buffer(6)]], \ + const constant int *idx_shapes [[buffer(7)]], \ + const constant size_t *idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(idx_t) \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); + +#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \ + instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) + +#define instantiate_gather4(name, src_t, idx_t, nidx) \ + instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \ + instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \ + instantiate_gather5(name, src_t, idx_t, nidx, 2, ) + + +// Special for case NIDX=0 +instantiate_gather4("bool_", bool, bool, 0) +instantiate_gather4("uint8", uint8_t, bool, 0) +instantiate_gather4("uint16", uint16_t, bool, 0) +instantiate_gather4("uint32", uint32_t, bool, 0) +instantiate_gather4("uint64", uint64_t, bool, 0) +instantiate_gather4("int8", int8_t, bool, 0) +instantiate_gather4("int16", int16_t, bool, 0) +instantiate_gather4("int32", int32_t, bool, 0) +instantiate_gather4("int64", int64_t, bool, 0) +instantiate_gather4("float16", half, bool, 0) +instantiate_gather4("float32", float, bool, 0) +instantiate_gather4("bfloat16", bfloat16_t, bool, 0) + +#define instantiate_gather3(name, src_type, ind_type) \ + instantiate_gather4(name, src_type, ind_type, 1) \ + instantiate_gather4(name, src_type, ind_type, 2) \ + instantiate_gather4(name, src_type, ind_type, 3) \ + instantiate_gather4(name, src_type, ind_type, 4) \ + instantiate_gather4(name, src_type, ind_type, 5) \ + instantiate_gather4(name, src_type, ind_type, 6) \ + instantiate_gather4(name, src_type, ind_type, 7) \ + instantiate_gather4(name, src_type, ind_type, 8) \ + instantiate_gather4(name, src_type, ind_type, 9) \ + instantiate_gather4(name, src_type, ind_type, 10) + +#define instantiate_gather(name, src_type) \ + instantiate_gather3(#name "bool_", src_type, bool) \ + instantiate_gather3(#name "uint8", src_type, uint8_t) \ + instantiate_gather3(#name "uint16", src_type, uint16_t) \ + instantiate_gather3(#name "uint32", src_type, uint32_t) \ + instantiate_gather3(#name "uint64", src_type, uint64_t) \ + instantiate_gather3(#name "int8", src_type, int8_t) \ + instantiate_gather3(#name "int16", src_type, int16_t) \ + instantiate_gather3(#name "int32", src_type, int32_t) \ + instantiate_gather3(#name "int64", src_type, int64_t) + +instantiate_gather(bool_, bool) +instantiate_gather(uint8, uint8_t) +instantiate_gather(uint16, uint16_t) +instantiate_gather(uint32, uint32_t) +instantiate_gather(uint64, uint64_t) +instantiate_gather(int8, int8_t) +instantiate_gather(int16, int16_t) +instantiate_gather(int32, int32_t) +instantiate_gather(int64, int64_t) +instantiate_gather(float16, half) +instantiate_gather(float32, float) +instantiate_gather(bfloat16, bfloat16_t) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h new file mode 100644 index 000000000..c2b37f3ff --- /dev/null +++ b/mlx/backend/metal/kernels/indexing.h @@ -0,0 +1,54 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Indexing utils +///////////////////////////////////////////////////////////////////// + +template +struct Indices { + const array buffers; + const constant int* shapes; + const constant size_t* strides; + const int ndim; +}; + +template +METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { + if (is_unsigned_v) { + return idx; + } else { + return (idx < 0) ? idx + size : idx; + } +} + +#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]], + +#define IDX_ARG_0(idx_t) +#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21) +#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22) +#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23) +#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24) +#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25) +#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26) +#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27) +#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28) +#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29) +#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30) + +#define IDX_ARR_N(n) idx##n, + +#define IDX_ARR_0() +#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21) +#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22) +#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23) +#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24) +#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25) +#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26) +#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27) +#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28) +#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29) +#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal deleted file mode 100644 index 395bc7819..000000000 --- a/mlx/backend/metal/kernels/indexing.metal +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/reduce.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; - -///////////////////////////////////////////////////////////////////// -// Gather kernel -///////////////////////////////////////////////////////////////////// - -template -struct Indices { - const array buffers [[id(0)]]; - device int* shapes [[id(NIDX + 1)]]; - device size_t* strides [[id(NIDX + 2)]]; - const int ndim [[id(NIDX + 3)]]; -}; - -template -inline size_t offset_neg_idx(IdxT idx, size_t size) { - return (idx < 0) ? idx + size : idx; -} - -template <> -inline size_t offset_neg_idx(bool idx, size_t) { - return idx; -} - -template <> -inline size_t offset_neg_idx(uint32_t idx, size_t) { - return idx; -} - -// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time -// special case for 0 and 1. Anything >= 2 uses the general case -template -[[kernel]] void gather( - const device T *src [[buffer(0)]], - const constant Indices& indices [[buffer(1)]], - device T *out [[buffer(2)]], - const constant int *src_shape [[buffer(3)]], - const constant size_t *src_strides [[buffer(4)]], - const constant size_t& src_ndim [[buffer(5)]], - const constant int *slice_sizes [[buffer(6)]], - const constant int *axes [[buffer(7)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - - auto ind_idx = index.x; - auto ind_offset = index.y; - - size_t src_idx = 0; - for (int i = 0; i < NIDX; ++i) { - size_t idx_loc; - if (IDX_NDIM == 0) { - idx_loc = 0; - } else if (IDX_NDIM == 1) { - idx_loc = ind_idx * indices.strides[indices.ndim * i]; - } else { - idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - } - auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += idx_val * src_strides[ax]; - } - - auto src_offset = elem_to_loc( - ind_offset, slice_sizes, src_strides, src_ndim); - - size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; - out[out_idx] = src[src_offset + src_idx]; -} - -#define instantiate_gather4(name, src_type, ind_type, nindex) \ -template [[host_name("gather" name "_" #nindex "_0")]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ -template [[host_name("gather" name "_" #nindex "_1")]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ -template [[host_name("gather" name "_" #nindex)]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); - - -// Special for case NIDX=0 -instantiate_gather4("bool_", bool, bool, 0) -instantiate_gather4("uint8", uint8_t, bool, 0) -instantiate_gather4("uint16", uint16_t, bool, 0) -instantiate_gather4("uint32", uint32_t, bool, 0) -instantiate_gather4("uint64", uint64_t, bool, 0) -instantiate_gather4("int8", int8_t, bool, 0) -instantiate_gather4("int16", int16_t, bool, 0) -instantiate_gather4("int32", int32_t, bool, 0) -instantiate_gather4("int64", int64_t, bool, 0) -instantiate_gather4("float16", half, bool, 0) -instantiate_gather4("float32", float, bool, 0) -instantiate_gather4("bfloat16", bfloat16_t, bool, 0) - -#define instantiate_gather3(name, src_type, ind_type) \ - instantiate_gather4(name, src_type, ind_type, 1) \ - instantiate_gather4(name, src_type, ind_type, 2) \ - instantiate_gather4(name, src_type, ind_type, 3) \ - instantiate_gather4(name, src_type, ind_type, 4) \ - instantiate_gather4(name, src_type, ind_type, 5) \ - instantiate_gather4(name, src_type, ind_type, 6) \ - instantiate_gather4(name, src_type, ind_type, 7) \ - instantiate_gather4(name, src_type, ind_type, 8) \ - instantiate_gather4(name, src_type, ind_type, 9) \ - instantiate_gather4(name, src_type, ind_type, 10) - -#define instantiate_gather(name, src_type) \ - instantiate_gather3(#name "bool_", src_type, bool) \ - instantiate_gather3(#name "uint8", src_type, uint8_t) \ - instantiate_gather3(#name "uint16", src_type, uint16_t) \ - instantiate_gather3(#name "uint32", src_type, uint32_t) \ - instantiate_gather3(#name "uint64", src_type, uint64_t) \ - instantiate_gather3(#name "int8", src_type, int8_t) \ - instantiate_gather3(#name "int16", src_type, int16_t) \ - instantiate_gather3(#name "int32", src_type, int32_t) \ - instantiate_gather3(#name "int64", src_type, int64_t) - -instantiate_gather(bool_, bool) -instantiate_gather(uint8, uint8_t) -instantiate_gather(uint16, uint16_t) -instantiate_gather(uint32, uint32_t) -instantiate_gather(uint64, uint64_t) -instantiate_gather(int8, int8_t) -instantiate_gather(int16, int16_t) -instantiate_gather(int32, int32_t) -instantiate_gather(int64, int64_t) -instantiate_gather(float16, half) -instantiate_gather(float32, float) -instantiate_gather(bfloat16, bfloat16_t) - -///////////////////////////////////////////////////////////////////// -// Scatter kernel -///////////////////////////////////////////////////////////////////// - -template -[[kernel]] void scatter( - const device Indices& indices [[buffer(0)]], - const device T *updates [[buffer(1)]], - device mlx_atomic *out [[buffer(2)]], - const device int *upd_shape [[buffer(3)]], - const device size_t *upd_strides [[buffer(4)]], - const device size_t& upd_ndim [[buffer(5)]], - const device size_t& upd_size [[buffer(6)]], - const device int *out_shape [[buffer(7)]], - const device size_t *out_strides [[buffer(8)]], - const device size_t& out_ndim [[buffer(9)]], - const device int* axes [[buffer(10)]], - uint gid [[thread_position_in_grid]]) { - - Op op; - auto ind_idx = gid / upd_size; - auto ind_offset = gid % upd_size; - - size_t out_idx = 0; - for (int i = 0; i < NIDX; ++i) { - auto idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += idx_val * out_strides[ax]; - } - - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); - auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx + out_offset); -} - -#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \ -template [[host_name("scatter" name "_" #nindex)]] \ -[[kernel]] void scatter( \ - const device Indices& indices [[buffer(0)]], \ - const device type *updates [[buffer(1)]], \ - device mlx_atomic *out [[buffer(2)]], \ - const device int *upd_shape [[buffer(3)]], \ - const device size_t *upd_strides [[buffer(4)]], \ - const device size_t& upd_ndim [[buffer(5)]], \ - const device size_t& upd_size [[buffer(6)]], \ - const device int *out_shape [[buffer(7)]], \ - const device size_t *out_strides [[buffer(8)]], \ - const device size_t& out_ndim [[buffer(9)]], \ - const device int* axes [[buffer(10)]], \ - uint gid [[thread_position_in_grid]]); - -// Special case NINDEX=0 -#define instantiate_scatter_nd0(name, type) \ - instantiate_scatter4(#name "none", type, bool, None, 0) \ - instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ - instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ - instantiate_scatter4(#name "_max", type, bool, Max, 0) \ - instantiate_scatter4(#name "_min", type, bool, Min, 0) - -#define instantiate_scatter3(name, type, ind_type, op_type) \ - instantiate_scatter4(name, type, ind_type, op_type, 1) \ - instantiate_scatter4(name, type, ind_type, op_type, 2) \ - instantiate_scatter4(name, type, ind_type, op_type, 3) \ - instantiate_scatter4(name, type, ind_type, op_type, 4) \ - instantiate_scatter4(name, type, ind_type, op_type, 5) \ - instantiate_scatter4(name, type, ind_type, op_type, 6) \ - instantiate_scatter4(name, type, ind_type, op_type, 7) \ - instantiate_scatter4(name, type, ind_type, op_type, 8) \ - instantiate_scatter4(name, type, ind_type, op_type, 9) \ - instantiate_scatter4(name, type, ind_type, op_type, 10) - -#define instantiate_scatter2(name, type, ind_type) \ - instantiate_scatter3(name "_none", type, ind_type, None) \ - instantiate_scatter3(name "_sum", type, ind_type, Sum) \ - instantiate_scatter3(name "_prod", type, ind_type, Prod) \ - instantiate_scatter3(name "_max", type, ind_type, Max) \ - instantiate_scatter3(name "_min", type, ind_type, Min) - -#define instantiate_scatter(name, type) \ - instantiate_scatter2(#name "bool_", type, bool) \ - instantiate_scatter2(#name "uint8", type, uint8_t) \ - instantiate_scatter2(#name "uint16", type, uint16_t) \ - instantiate_scatter2(#name "uint32", type, uint32_t) \ - instantiate_scatter2(#name "uint64", type, uint64_t) \ - instantiate_scatter2(#name "int8", type, int8_t) \ - instantiate_scatter2(#name "int16", type, int16_t) \ - instantiate_scatter2(#name "int32", type, int32_t) \ - instantiate_scatter2(#name "int64", type, int64_t) - -// TODO uint64 and int64 unsupported -instantiate_scatter_nd0(bool_, bool) -instantiate_scatter_nd0(uint8, uint8_t) -instantiate_scatter_nd0(uint16, uint16_t) -instantiate_scatter_nd0(uint32, uint32_t) -instantiate_scatter_nd0(int8, int8_t) -instantiate_scatter_nd0(int16, int16_t) -instantiate_scatter_nd0(int32, int32_t) -instantiate_scatter_nd0(float16, half) -instantiate_scatter_nd0(float32, float) -instantiate_scatter_nd0(bfloat16, bfloat16_t) - -instantiate_scatter(bool_, bool) -instantiate_scatter(uint8, uint8_t) -instantiate_scatter(uint16, uint16_t) -instantiate_scatter(uint32, uint32_t) -instantiate_scatter(int8, int8_t) -instantiate_scatter(int16, int16_t) -instantiate_scatter(int32, int32_t) -instantiate_scatter(float16, half) -instantiate_scatter(float32, float) -instantiate_scatter(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 5bf3142d4..c2bfba9f9 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -15,6 +15,14 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; +template struct AccT { + typedef T acc_t; +}; + +template <> struct AccT { + typedef float acc_t; +}; + template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], @@ -31,21 +39,23 @@ template ::acc_t U; + threadgroup U scales_block[BM * groups_per_block]; + threadgroup U biases_block[BM * groups_per_block]; + threadgroup U x_block[colgroup]; thread uint32_t w_local; - thread T result = 0; - thread T scale = 1; - thread T bias = 0; - thread T x_thread[el_per_thread]; + thread U result = 0; + thread U scale = 1; + thread U bias = 0; + thread U x_thread[el_per_thread]; // Adjust positions const int in_vec_size_w = in_vec_size / el_per_thread; @@ -57,12 +67,19 @@ template = out_vec_size) { + return; + } + // Loop over in_vec in blocks of colgroup for (int i=0; i(w_local & bitmask) + bias) * x_thread[k]; + result += (scale * static_cast(w_local & bitmask) + bias) * x_thread[k]; w_local >>= bits; } } @@ -100,7 +117,7 @@ template (result); } } @@ -129,15 +146,16 @@ template ::acc_t U; + threadgroup U scales_block[BM * groups_per_block]; + threadgroup U biases_block[BM * groups_per_block]; + threadgroup U x_block[BM]; thread uint32_t w_local; - thread T result[el_per_int] = {0}; - thread T scale = 1; - thread T bias = 0; - thread T x_local = 0; + thread U result[el_per_int] = {0}; + thread U scale = 1; + thread U bias = 0; + thread U x_local = 0; // Adjust positions const int out_vec_size_w = out_vec_size / el_per_int; @@ -186,7 +204,7 @@ template (w_local & bitmask) + bias) * x_local; + result[k] += (scale * static_cast(w_local & bitmask) + bias) * x_local; w_local >>= bits; } } @@ -201,7 +219,7 @@ template (result[k]); } } } @@ -240,7 +258,6 @@ template ; using loader_x_t = mlx::steel::BlockLoader; - threadgroup T scales_block[BN * groups_per_block]; threadgroup T biases_block[BN * groups_per_block]; threadgroup T Xs[BM * BK]; @@ -303,7 +320,7 @@ template = K) { + if (num_k < BK) { for (int wo=0; wo + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +template +[[kernel]] void rope( + const device T *in [[buffer(0)]], + device T * out [[buffer(1)]], + constant const size_t strides[3], + constant const int& offset, + constant const float& base, + constant const float& scale, + uint3 pos [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute the input and output indices + uint in_index_1, in_index_2; + uint out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z)); + out_index_2 = out_index_1 + 1; + in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z)); + out_index_2 = out_index_1 + grid.x; + in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; + in_index_2 = in_index_1 + grid.x * strides[2]; + } + + // Figure out L and d. + float L = scale * static_cast(pos.y + offset); + float d = static_cast(pos.x) / static_cast(grid.x); + + // Compute costheta, sintheta + float theta = L * metal::exp2(-d * base); + float costheta = metal::fast::cos(theta); + float sintheta = metal::fast::sin(theta); + + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1 = x1 * costheta - x2 * sintheta; + float rx2 = x1 * sintheta + x2 * costheta; + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); +} + +#define instantiate_rope(name, type, traditional) \ + template [[host_name("rope_" #name)]] \ + [[kernel]] void rope( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const size_t strides[3], \ + constant const int& offset, \ + constant const float& base, \ + constant const float& scale, \ + uint3 pos [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); + +instantiate_rope(traditional_float16, half, true) +instantiate_rope(traditional_bfloat16, bfloat16_t, true) +instantiate_rope(traditional_float32, float, true) +instantiate_rope(float16, half, false) +instantiate_rope(bfloat16, bfloat16_t, false) +instantiate_rope(float32, float, false) diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal new file mode 100644 index 000000000..7a94be7da --- /dev/null +++ b/mlx/backend/metal/kernels/scatter.metal @@ -0,0 +1,194 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/indexing.h" +#include "mlx/backend/metal/kernels/reduce.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Scatter kernel +///////////////////////////////////////////////////////////////////// + + +template +METAL_FUNC void scatter_impl( + const device T *updates [[buffer(1)]], + device mlx_atomic *out [[buffer(2)]], + const constant int *upd_shape [[buffer(3)]], + const constant size_t *upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int *out_shape [[buffer(7)]], + const constant size_t *out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* axes [[buffer(10)]], + const thread Indices& indices, + uint2 gid [[thread_position_in_grid]]) { + + Op op; + auto ind_idx = gid.y; + auto ind_offset = gid.x; + + size_t out_idx = 0; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += idx_val * out_strides[ax]; + } + + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); + op.atomic_update(out, updates[upd_idx], out_idx + out_offset); +} + +#define make_scatter_impl(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void scatter( \ + const device T *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int *upd_shape [[buffer(3)]], \ + const constant size_t *upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int *out_shape [[buffer(7)]], \ + const constant size_t *out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int *idx_shapes [[buffer(11)]], \ + const constant size_t *idx_strides [[buffer(12)]], \ + const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(IdxT) \ + uint2 gid [[thread_position_in_grid]]) { \ + \ + Indices idxs{ \ + {{IDX_ARR()}}, \ + idx_shapes, \ + idx_strides, \ + idx_ndim}; \ + \ + return scatter_impl( \ + updates, \ + out, \ + upd_shape, \ + upd_strides, \ + upd_ndim, \ + upd_size, \ + out_shape, \ + out_strides, \ + out_ndim, \ + axes, \ + idxs, \ + gid); \ +} + +#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) + +make_scatter(0) +make_scatter(1) +make_scatter(2) +make_scatter(3) +make_scatter(4) +make_scatter(5) +make_scatter(6) +make_scatter(7) +make_scatter(8) +make_scatter(9) +make_scatter(10) + +///////////////////////////////////////////////////////////////////// +// Scatter instantiations +///////////////////////////////////////////////////////////////////// + +#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ +template [[host_name("scatter" name "_" #nidx)]] \ +[[kernel]] void scatter( \ + const device src_t *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int *upd_shape [[buffer(3)]], \ + const constant size_t *upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int *out_shape [[buffer(7)]], \ + const constant size_t *out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int *idx_shapes [[buffer(11)]], \ + const constant size_t *idx_strides [[buffer(12)]], \ + const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(idx_t) \ + uint2 gid [[thread_position_in_grid]]); + +#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ + instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) + +// Special case NINDEX=0 +#define instantiate_scatter_nd0(name, type) \ + instantiate_scatter4(#name "none", type, bool, None, 0) \ + instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ + instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ + instantiate_scatter4(#name "_max", type, bool, Max, 0) \ + instantiate_scatter4(#name "_min", type, bool, Min, 0) + +#define instantiate_scatter3(name, type, ind_type, op_type) \ + instantiate_scatter4(name, type, ind_type, op_type, 1) \ + instantiate_scatter4(name, type, ind_type, op_type, 2) \ + instantiate_scatter4(name, type, ind_type, op_type, 3) \ + instantiate_scatter4(name, type, ind_type, op_type, 4) \ + instantiate_scatter4(name, type, ind_type, op_type, 5) \ + instantiate_scatter4(name, type, ind_type, op_type, 6) \ + instantiate_scatter4(name, type, ind_type, op_type, 7) \ + instantiate_scatter4(name, type, ind_type, op_type, 8) \ + instantiate_scatter4(name, type, ind_type, op_type, 9) \ + instantiate_scatter4(name, type, ind_type, op_type, 10) + +#define instantiate_scatter2(name, type, ind_type) \ + instantiate_scatter3(name "_none", type, ind_type, None) \ + instantiate_scatter3(name "_sum", type, ind_type, Sum) \ + instantiate_scatter3(name "_prod", type, ind_type, Prod) \ + instantiate_scatter3(name "_max", type, ind_type, Max) \ + instantiate_scatter3(name "_min", type, ind_type, Min) + +#define instantiate_scatter(name, type) \ + instantiate_scatter2(#name "bool_", type, bool) \ + instantiate_scatter2(#name "uint8", type, uint8_t) \ + instantiate_scatter2(#name "uint16", type, uint16_t) \ + instantiate_scatter2(#name "uint32", type, uint32_t) \ + instantiate_scatter2(#name "uint64", type, uint64_t) \ + instantiate_scatter2(#name "int8", type, int8_t) \ + instantiate_scatter2(#name "int16", type, int16_t) \ + instantiate_scatter2(#name "int32", type, int32_t) \ + instantiate_scatter2(#name "int64", type, int64_t) + +// TODO uint64 and int64 unsupported +instantiate_scatter_nd0(bool_, bool) +instantiate_scatter_nd0(uint8, uint8_t) +instantiate_scatter_nd0(uint16, uint16_t) +instantiate_scatter_nd0(uint32, uint32_t) +instantiate_scatter_nd0(int8, int8_t) +instantiate_scatter_nd0(int16, int16_t) +instantiate_scatter_nd0(int32, int32_t) +instantiate_scatter_nd0(float16, half) +instantiate_scatter_nd0(float32, float) +instantiate_scatter_nd0(bfloat16, bfloat16_t) + +instantiate_scatter(bool_, bool) +instantiate_scatter(uint8, uint8_t) +instantiate_scatter(uint16, uint16_t) +instantiate_scatter(uint32, uint32_t) +instantiate_scatter(int8, int8_t) +instantiate_scatter(int16, int16_t) +instantiate_scatter(int32, int32_t) +instantiate_scatter(float16, half) +instantiate_scatter(float32, float) +instantiate_scatter(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h new file mode 100644 index 000000000..6d086b775 --- /dev/null +++ b/mlx/backend/metal/kernels/unary.h @@ -0,0 +1,376 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct Abs { + template + T operator()(T x) { + return metal::abs(x); + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return metal::precise::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return metal::precise::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return metal::precise::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return metal::precise::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return metal::precise::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return metal::precise::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return metal::ceil(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return metal::precise::cos(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Cosh { + template + T operator()(T x) { + return metal::precise::cosh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return metal::precise::exp(x); + }; + template <> + complex64_t operator()(complex64_t x) { + auto m = metal::precise::exp(x.real); + return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + } +}; + +struct Floor { + template + T operator()(T x) { + return metal::floor(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return metal::precise::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return metal::precise::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return metal::precise::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return metal::rint(x); + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::rint(x.real), metal::rint(x.imag)}; + }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + }; + template <> + uint32_t operator()(uint32_t x) { + return x != 0; + }; +}; + +struct Sin { + template + T operator()(T x) { + return metal::precise::sin(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Sinh { + template + T operator()(T x) { + return metal::precise::sinh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return metal::precise::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return metal::precise::rsqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return metal::precise::tan(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + }; +}; + +struct Tanh { + template + T operator()(T x) { + return metal::precise::tanh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + }; +}; diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 681d7707f..154db0520 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -1,223 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/erf.h" -#include "mlx/backend/metal/kernels/bf16.h" - -struct Abs { - template T operator()(T x) { return metal::abs(x); }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; - template <> complex64_t operator()(complex64_t x) { - return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; - }; -}; - -struct ArcCos { - template T operator()(T x) { return metal::precise::acos(x); }; -}; - -struct ArcCosh { - template T operator()(T x) { return metal::precise::acosh(x); }; -}; - -struct ArcSin { - template T operator()(T x) { return metal::precise::asin(x); }; -}; - -struct ArcSinh { - template T operator()(T x) { return metal::precise::asinh(x); }; -}; - -struct ArcTan { - template T operator()(T x) { return metal::precise::atan(x); }; -}; - -struct ArcTanh { - template T operator()(T x) { return metal::precise::atanh(x); }; -}; - -struct Ceil { - template T operator()(T x) { return metal::ceil(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; -}; - -struct Cos { - template T operator()(T x) { return metal::precise::cos(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cos(x.real) * metal::precise::cosh(x.imag), - -metal::precise::sin(x.real) * metal::precise::sinh(x.imag) - }; - }; -}; - -struct Cosh { - template T operator()(T x) { return metal::precise::cosh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cosh(x.real) * metal::precise::cos(x.imag), - metal::precise::sinh(x.real) * metal::precise::sin(x.imag) - }; - }; -}; - -struct Erf { - template T operator()(T x) { return static_cast(erf(static_cast(x))); }; -}; - -struct ErfInv { - template T operator()(T x) { return static_cast(erfinv(static_cast(x))); }; -}; - -struct Exp { - template T operator()(T x) { return metal::precise::exp(x); }; - template <> complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; - } -}; - -struct Floor { - template T operator()(T x) { return metal::floor(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; -}; - -struct Log { - template T operator()(T x) { return metal::precise::log(x); }; -}; - -struct Log2 { - template T operator()(T x) { return metal::precise::log2(x); }; -}; - -struct Log10 { - template T operator()(T x) { return metal::precise::log10(x); }; -}; - -struct Log1p { - template T operator()(T x) { return log1p(x); }; -}; - -struct LogicalNot { - template T operator()(T x) { return !x; }; -}; - -struct Negative { - template T operator()(T x) { return -x; }; -}; - -struct Round { - template T operator()(T x) { return metal::rint(x); }; - template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; }; -}; - -struct Sigmoid { - template - T operator()(T x) { - auto y = 1 / (1 + metal::exp(-metal::abs(x))); - return (x < 0) ? 1 - y : y; - } -}; - -struct Sign { - template T operator()(T x) { return (x > T(0)) - (x < T(0)); }; - template <> uint32_t operator()(uint32_t x) { return x != 0; }; -}; - -struct Sin { - template T operator()(T x) { return metal::precise::sin(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sin(x.real) * metal::precise::cosh(x.imag), - metal::precise::cos(x.real) * metal::precise::sinh(x.imag) - }; - }; -}; - -struct Sinh { - template T operator()(T x) { return metal::precise::sinh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sinh(x.real) * metal::precise::cos(x.imag), - metal::precise::cosh(x.real) * metal::precise::sin(x.imag) - }; - }; -}; - -struct Square { - template T operator()(T x) { return x * x; }; -}; - -struct Sqrt { - template T operator()(T x) { return metal::precise::sqrt(x); }; -}; - -struct Rsqrt { - template T operator()(T x) { return metal::precise::rsqrt(x); }; -}; - -struct Tan { - template T operator()(T x) { return metal::precise::tan(x); }; - - template <> - complex64_t operator()(complex64_t x) { - float tan_a = metal::precise::tan(x.real); - float tanh_b = metal::precise::tanh(x.imag); - float t1 = tan_a * tanh_b; - float denom = 1. + t1 * t1; - return { - (tan_a - tanh_b * t1) / denom, - (tanh_b + tan_a * t1) / denom - }; - }; -}; - -struct Tanh { - template T operator()(T x) { return metal::precise::tanh(x); }; - - template <> - complex64_t operator()(complex64_t x) { - float tanh_a = metal::precise::tanh(x.real); - float tan_b = metal::precise::tan(x.imag); - float t1 = tanh_a * tan_b; - float denom = 1. + t1 * t1; - return { - (tanh_a + tan_b * t1) / denom, - (tan_b - tanh_a * t1) / denom - }; - }; -}; +#include "mlx/backend/metal/kernels/unary.h" template [[kernel]] void unary_op_v( diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 634a9d6df..8ef1127b6 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -12,10 +12,10 @@ template struct Limits { - static const constant U max; - static const constant U min; - static const constant U finite_max; - static const constant U finite_min; + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); }; #define instantiate_default_limit(type) \ @@ -71,7 +71,7 @@ inline size_t elem_to_loc( device const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; } @@ -84,7 +84,7 @@ inline size_t elem_to_loc( constant const size_t* strides, int ndim) { size_t loc = 0; - for (int i = ndim - 1; i >= 0; --i) { + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; } @@ -273,4 +273,4 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { inline bool simd_shuffle_down(bool data, uint16_t delta) { return simd_shuffle_down(static_cast(data), delta); -} \ No newline at end of file +} diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh new file mode 100644 index 000000000..26b575de4 --- /dev/null +++ b/mlx/backend/metal/make_compiled_preamble.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# +# This script generates a C++ function that provides the Metal unary and binary +# ops at runtime for use with kernel generation. +# +# Copyright © 2023-24 Apple Inc. + + +OUTPUT_FILE=$1 +CC=$2 +SRCDIR=$3 + +CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null) + +cat << EOF > "$OUTPUT_FILE" +// Copyright © 2023-24 Apple Inc. + +namespace mlx::core::metal { + +const char* get_kernel_preamble() { + return R"preamble( +$CONTENT +)preamble"; + +} + +} // namespace mlx::core::metal +EOF diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 757888a66..38ec5993c 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -691,7 +691,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { // organize into grid nkeys x elem_per_key MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); auto compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 7754711e3..839d0b336 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int bo = std::min(32, O); int bd = 32; MTL::Size group_dims = MTL::Size(bd, bo, 1); - MTL::Size grid_dims = MTL::Size(1, O / bo, B); + MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B); set_array_buffer(compute_encoder, w, 0); set_array_buffer(compute_encoder, scales, 1); diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp new file mode 100644 index 000000000..fdea57985 --- /dev/null +++ b/mlx/backend/metal/rope.cpp @@ -0,0 +1,54 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/utils.h" +#include "mlx/fast_primitives.h" + +namespace mlx::core::fast { + +void RoPE::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + auto& in = inputs[0]; + auto& out = outputs[0]; + + if (in.ndim() != 3) { + throw std::runtime_error( + "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)"); + } + if (dims_ != in.shape(-1)) { + throw std::runtime_error("[RoPE] Partial RoPE application not supported"); + } + if (in.flags().row_contiguous && in.is_donatable()) { + out.move_shared_buffer(in); + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + std::ostringstream kname; + kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in); + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + + bool donated = in.data_shared_ptr() == nullptr; + float base = std::log2(base_); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, donated ? out : in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2); + compute_encoder->setBytes(&offset_, sizeof(int), 3); + compute_encoder->setBytes(&base, sizeof(float), 4); + compute_encoder->setBytes(&scale_, sizeof(float), 5); + + int dim0 = in.shape(2) / 2; + int dim1 = in.shape(1); + int dim2 = in.shape(0); + auto group_dims = get_block_dims(dim0, dim1, dim2); + auto grid_dims = MTL::Size(dim0, dim1, dim2); + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 33ec8014c..be25bc032 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous std::vector copies; auto check_input = [&copies, &s](const array& x) { - if (x.strides()[x.ndim() - 1] == 1) { + bool no_copy = x.strides()[x.ndim() - 1] == 1; + if (x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 378850802..363632a30 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -9,20 +9,6 @@ namespace mlx::core { namespace { -void set_array_buffer( - MTL::ComputeCommandEncoder* compute_encoder, - MTL::ArgumentEncoder* enc, - const array& a, - int idx) { - auto a_buf = static_cast(a.buffer().ptr()); - auto offset = a.data() - - static_cast(const_cast(a_buf)->contents()); - enc->setBuffer(a_buf, offset, idx); - // MTL::Resource usage through argument buffer needs to be explicitly - // flagged to enable hazard tracking - compute_encoder->useResource(a_buf, MTL::ResourceUsageRead); -} - void set_array_buffer( MTL::ComputeCommandEncoder* enc, const array& a, @@ -117,16 +103,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) { // When multiple arrays are passed they should all have the same shape. The // collapsed axes are also the same so one shape is returned. std::tuple, std::vector>> -collapse_contiguous_dims(const std::vector& xs) { +collapse_contiguous_dims( + const std::vector& shape, + const std::vector> strides) { // Make a vector that has axes separated with -1. Collapse all axes between // -1. std::vector to_collapse; - if (xs[0].ndim() > 0) { + if (shape.size() > 0) { to_collapse.push_back(0); - for (int i = 1; i < xs[0].ndim(); i++) { + for (int i = 1; i < shape.size(); i++) { bool contiguous = true; - for (auto& x : xs) { - if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) { + for (const std::vector& st : strides) { + if (st[i] * shape[i] != st[i - 1]) { contiguous = false; } if (!contiguous) { @@ -142,21 +130,31 @@ collapse_contiguous_dims(const std::vector& xs) { } std::vector out_shape; - std::vector> out_strides(xs.size()); + std::vector> out_strides(strides.size()); for (int i = 0; i < to_collapse.size(); i++) { - int current_shape = xs[0].shape()[to_collapse[i]]; + int current_shape = shape[to_collapse[i]]; while (to_collapse[++i] != -1) { - current_shape *= xs[0].shape()[to_collapse[i]]; + current_shape *= shape[to_collapse[i]]; } out_shape.push_back(current_shape); - for (int j = 0; j < xs.size(); j++) { - out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]); + for (int j = 0; j < strides.size(); j++) { + const std::vector& st = strides[j]; + out_strides[j].push_back(st[to_collapse[i - 1]]); } } return std::make_tuple(out_shape, out_strides); } +std::tuple, std::vector>> +collapse_contiguous_dims(const std::vector& xs) { + std::vector> strides; + for (auto& x : xs) { + strides.emplace_back(x.strides()); + } + return collapse_contiguous_dims(xs[0].shape(), strides); +} + template std::tuple, std::vector>> collapse_contiguous_dims(Arrays... xs) { diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index dd4edc2ed..8e66f56b3 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/primitives.h" +#include "mlx/fast_primitives.h" #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ @@ -95,4 +96,8 @@ NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Transpose) +namespace fast { +NO_GPU_MULTI(RoPE) +} // namespace fast + } // namespace mlx::core diff --git a/mlx/compile.cpp b/mlx/compile.cpp index fa9e0a987..700c07ced 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -13,7 +13,7 @@ namespace mlx::core { -constexpr int max_compile_depth = 6; +constexpr int max_compile_depth = 11; bool is_unary(const Primitive& p) { return ( @@ -55,19 +55,20 @@ bool is_noop(const Primitive& p) { return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient); } +bool is_reduction(const Primitive& p) { + return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce); +} + bool is_fusable(const Primitive& p) { return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p); } -namespace detail { - -std::vector compile_replace( - const std::vector& tape, - const std::vector& trace_inputs, - const std::vector& trace_outputs, - const std::vector& inputs); - -} // namespace detail +bool allows_shapeless(const Primitive& p) { + return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) || + is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) || + typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) || + typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition); +} Compiled::Compiled( Stream stream, @@ -123,6 +124,23 @@ void Compiled::print(std::ostream& os) { } } +std::vector> Compiled::output_shapes( + const std::vector& inputs) { + size_t nd = 0; + for (auto& in : inputs) { + nd = std::max(nd, in.ndim()); + } + std::vector out_shape(nd, 0); + for (auto& in : inputs) { + auto dd = nd - in.ndim(); + for (auto i = dd; i < nd; ++i) { + out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]); + } + } + // All outputs have the same shape + return std::vector>(outputs_.size(), out_shape); +} + namespace detail { CompileMode& compile_mode() { @@ -180,24 +198,30 @@ struct CompilerCache { std::vector outputs; std::vector tape; bool empty{true}; + std::vector constants; }; // Returns a reference to a CacheEntry which can be updated // by the caller to avoid copying large tapes / inputs / outputs - CacheEntry& find(size_t fun_id, const std::vector& inputs) { + CacheEntry& find( + size_t fun_id, + const std::vector& inputs, + bool shapeless, + const std::vector& constants) { // Try to find the entry auto [entry_it, inserted] = cache_.insert({fun_id, {}}); auto& entries = entry_it->second; - auto is_match = [](const std::vector& in1, - const std::vector& in2) { + auto is_match = [shapeless]( + const std::vector& in1, + const std::vector& in2) { if (in1.size() != in2.size()) { - std::ostringstream msg; - msg << "[compiler] Unexpected number of inputs to compiled function:" - << " expected " << in2.size() << " got " << in1.size() << "."; - throw std::invalid_argument(msg.str()); + return false; } for (int i = 0; i < in1.size(); ++i) { - if (in1[i].shape() != in2[i].shape()) { + if (in1[i].ndim() != in2[i].ndim()) { + return false; + } + if (!shapeless && in1[i].shape() != in2[i].shape()) { return false; } if (in1[i].dtype() != in2[i].dtype()) { @@ -213,7 +237,7 @@ struct CompilerCache { // more easily searchable structure. for (auto& entry : entries) { // Check the inputs match and return if so - if (is_match(inputs, entry.inputs)) { + if (is_match(inputs, entry.inputs) && constants == entry.constants) { return entry; } } @@ -322,6 +346,9 @@ void compile_simplify( case 1: v = *a.data(); break; + case 2: + v = *a.data(); + break; case 4: v = *a.data(); break; @@ -603,7 +630,7 @@ void compile_fuse( shapes, types, std::make_shared( - outputs.back().primitive().stream(), + old_outputs.back().primitive().stream(), inputs, old_outputs, std::move(fused_tape), @@ -651,7 +678,8 @@ std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, const std::vector& trace_outputs, - const std::vector& inputs) { + const std::vector& inputs, + bool shapeless) { std::unordered_map trace_to_real; for (int i = 0; i < inputs.size(); ++i) { trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); @@ -669,18 +697,29 @@ std::vector compile_replace( real_inputs.push_back(trace_to_real.at(in.id())); } if (a.siblings().empty()) { + auto shape = + shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape(); auto real_a = array( - a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); + std::move(shape), + a.dtype(), + a.primitive_ptr(), + std::move(real_inputs)); trace_to_real.insert({a.id(), std::move(real_a)}); } else { // Ensure the order is correct for multi-output primitives - std::vector> shapes; std::vector types; auto trace_out = a.outputs(); for (auto& o : trace_out) { - shapes.push_back(o.shape()); types.push_back(o.dtype()); } + std::vector> shapes; + if (shapeless) { + shapes = a.primitive().output_shapes(real_inputs); + } else { + for (auto& o : trace_out) { + shapes.push_back(o.shape()); + } + } auto real_out = array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs); for (int i = 0; i < trace_out.size(); ++i) { @@ -697,13 +736,34 @@ std::vector compile_replace( return outputs; } +void compile_validate_shapeless(const std::vector& tape) { + for (auto& t : tape) { + if (!t.has_primitive()) { + continue; + } + auto& p = t.primitive(); + if (allows_shapeless(p)) { + continue; + } + + std::ostringstream msg; + msg << "[compile] Cannot compile primitive "; + p.print(msg); + msg << " with shapeless enabled."; + throw std::invalid_argument(msg.str()); + } +} + std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id) { + size_t fun_id, + bool shapeless /* = false */, + std::vector constants /* = {} */) { if (compile_mode() == CompileMode::disabled) { return fun; } - return [fun, fun_id](const std::vector& inputs) { + return [fun, fun_id, shapeless, constants = std::move(constants)]( + const std::vector& inputs) { // If the inputs are tracers, trace the original graph if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { return in.is_tracer(); @@ -712,12 +772,14 @@ std::function(const std::vector&)> compile( } // Find a cache entry with the correct inputs - auto& entry = compiler_cache().find(fun_id, inputs); + auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants); // No matching cache entry existed, so compile if (entry.empty) { // Mark the entry as not empty since we are about to fill it entry.empty = false; + // Set the constants + entry.constants = std::move(constants); // Trace to build the graph std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); @@ -739,11 +801,16 @@ std::function(const std::vector&)> compile( if (compile_mode() != CompileMode::no_fuse) { compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs); } + + if (shapeless) { + compile_validate_shapeless(entry.tape); + } } // At this point we must have a tape, now replace the placeholders // with real arrays that can be evaluated - return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs); + return compile_replace( + entry.tape, entry.inputs, entry.outputs, inputs, shapeless); }; } @@ -754,12 +821,13 @@ void compile_erase(size_t fun_id) { } // namespace detail std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun) { + const std::function(const std::vector&)>& fun, + bool shapeless /* false */) { if (detail::compile_mode() == CompileMode::disabled) { return fun; } auto fun_id = detail::getAddress(fun); - return detail::compile(fun, fun_id); + return detail::compile(fun, fun_id, shapeless); } void disable_compile() { diff --git a/mlx/compile.h b/mlx/compile.h index fb3115d61..1134c20dc 100644 --- a/mlx/compile.h +++ b/mlx/compile.h @@ -8,9 +8,10 @@ namespace mlx::core { enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; -// Compile takes a function and returns a new function +/** Compile takes a function and returns a compiled function. */ std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun); + const std::function(const std::vector&)>& fun, + bool shapeless = false); /** Globally disable compilation. * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also diff --git a/mlx/fast.cpp b/mlx/fast.cpp new file mode 100644 index 000000000..ee28138f1 --- /dev/null +++ b/mlx/fast.cpp @@ -0,0 +1,130 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/fast.h" +#include "mlx/fast_primitives.h" +#include "mlx/ops.h" +#include "mlx/transforms.h" + +namespace mlx::core::fast { + +std::vector Custom::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents); + std::vector vjp_outs; + for (int i = 0, j = 0; i < vjps.size(); ++i) { + if (i < argnums.size() && i == argnums[j]) { + vjp_outs.push_back(vjps[i]); + j++; + } + } + return vjp_outs; +} + +std::vector Custom::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents); + std::vector jvp_outs; + for (int i = 0, j = 0; i < jvps.size(); ++i) { + if (i < argnums.size() && i == argnums[j]) { + jvp_outs.push_back(jvps[i]); + j++; + } + } + return jvp_outs; +} + +std::pair, std::vector> Custom::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto outputs = mlx::core::vmap(fallback_, axes)(inputs); + auto out_axes = std::vector(outputs.size(), 0); + return {outputs, out_axes}; +} + +array rope( + const array& x, + int dims, + bool traditional, + float base, + float scale, + int offset, + StreamOrDevice s /* = {} */) { + if (x.ndim() != 3) { + std::ostringstream msg; + msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if (traditional && x.shape(-1) != dims) { + throw std::invalid_argument( + "[rope] Does not support partial traditional application."); + } + + auto fallback = [dims, traditional, base, scale, offset, s]( + const std::vector& inputs) { + auto& x = inputs[0]; + auto t = x.dtype(); + auto N = x.shape(1) + offset; + // Compute sines and cosines + auto half_dims = dims / 2; + auto positions = multiply(arange(offset, N, t, s), array(scale, t), s); + auto freqs = negative(arange(0, half_dims, t, s), s); + freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s); + auto theta = + multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s); + auto coss = cos(theta, s); + auto sins = sin(theta, s); + + if (traditional) { + auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s); + auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s); + std::vector outs; + outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s)); + outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s)); + for (auto& o : outs) { + o = expand_dims(o, 3, s); + } + return std::vector{reshape(concatenate(outs, 3, s), x.shape(), s)}; + } else { + auto out_s = x.shape(); + out_s.back() = half_dims; + auto x1 = slice(x, {0, 0, 0}, out_s, s); + out_s.back() = dims; + auto x2 = slice(x, {0, 0, half_dims}, out_s, s); + + std::vector outs; + outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s)); + outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s)); + if (dims < x.shape(-1)) { + outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); + } + return std::vector{concatenate(outs, 2, s)}; + } + }; + // TODO change to condition for using custom prim + auto stream = to_stream(s); + if (stream.device == Device::gpu && x.shape(-1) == dims) { + return array( + x.shape(), + x.dtype(), + std::make_unique( + stream, fallback, dims, traditional, base, scale, offset), + {x}); + } + return fallback({x})[0]; +} + +bool RoPE::is_equivalent(const Primitive& other) const { + const RoPE& a_other = static_cast(other); + return ( + dims_ == a_other.dims_ && base_ == a_other.base_ && + scale_ == a_other.scale_ && traditional_ == a_other.traditional_ && + offset_ == a_other.offset_); +} + +} // namespace mlx::core::fast diff --git a/mlx/fast.h b/mlx/fast.h new file mode 100644 index 000000000..48ac90a5a --- /dev/null +++ b/mlx/fast.h @@ -0,0 +1,18 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/utils.h" + +namespace mlx::core::fast { + +array rope( + const array& x, + int dims, + bool traditional, + float base, + float scale, + int offset, + StreamOrDevice s /* = {} */); + +} // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h new file mode 100644 index 000000000..2b854960b --- /dev/null +++ b/mlx/fast_primitives.h @@ -0,0 +1,68 @@ +#include "mlx/primitives.h" + +namespace mlx::core::fast { + +// Custom primitive accepts a fallback function which it uses for +// transformations. Transformations are virtual so that derived classes may +// override the default behavior. +class Custom : public Primitive { + public: + explicit Custom( + Stream stream, + std::function(std::vector)> fallback) + : Primitive(stream), fallback_(fallback){}; + + virtual std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + + virtual std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + + virtual std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + private: + std::function(std::vector)> fallback_; +}; + +class RoPE : public Custom { + public: + RoPE( + Stream stream, + std::function(std::vector)> fallback, + int dims, + bool traditional, + float base, + float scale, + int offset) + : Custom(stream, fallback), + dims_(dims), + traditional_(traditional), + base_(base), + scale_(scale), + offset_(offset){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_PRINT(RoPE) + bool is_equivalent(const Primitive& other) const override; + + private: + std::function(std::vector)> fallback_; + int dims_; + bool traditional_; + float base_; + float scale_; + int offset_; +}; + +} // namespace mlx::core::fast diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 7c1c17740..ba031c441 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -12,27 +12,24 @@ namespace mlx::core { -struct NodeNamer { - std::unordered_map names; - - std::string get_name(const array& x) { - auto it = names.find(x.id()); - if (it == names.end()) { - // Get the next name in the sequence - // [A, B, ..., Z, AA, AB, ...] - std::vector letters; - auto var_num = names.size() + 1; - while (var_num > 0) { - letters.push_back('A' + (var_num - 1) % 26); - var_num = (var_num - 1) / 26; - } - std::string name(letters.rbegin(), letters.rend()); - names.insert({x.id(), name}); - return name; +const std::string& NodeNamer::get_name(const array& x) { + auto it = names.find(x.id()); + if (it == names.end()) { + // Get the next name in the sequence + // [A, B, ..., Z, AA, AB, ...] + std::vector letters; + auto var_num = names.size() + 1; + while (var_num > 0) { + letters.push_back('A' + (var_num - 1) % 26); + var_num = (var_num - 1) / 26; } - return it->second; + std::string name(letters.rbegin(), letters.rend()); + names.insert({x.id(), name}); + + return get_name(x); } -}; + return it->second; +} void depth_first_traversal( std::function callback, diff --git a/mlx/graph_utils.h b/mlx/graph_utils.h index 3bd373bec..5e024704e 100644 --- a/mlx/graph_utils.h +++ b/mlx/graph_utils.h @@ -6,6 +6,12 @@ namespace mlx::core { +struct NodeNamer { + std::unordered_map names; + + const std::string& get_name(const array& x); +}; + void print_graph(std::ostream& os, const std::vector& outputs); template diff --git a/mlx/io.h b/mlx/io.h index c58e1959e..59866ea27 100644 --- a/mlx/io.h +++ b/mlx/io.h @@ -10,6 +10,14 @@ #include "mlx/stream.h" namespace mlx::core { +using GGUFMetaData = + std::variant>; +using GGUFLoad = std::pair< + std::unordered_map, + std::unordered_map>; +using SafetensorsLoad = std::pair< + std::unordered_map, + std::unordered_map>; /** Save array to out stream in .npy format */ void save(std::shared_ptr out_stream, array a); @@ -24,32 +32,29 @@ array load(std::shared_ptr in_stream, StreamOrDevice s = {}); array load(const std::string& file, StreamOrDevice s = {}); /** Load array map from .safetensors file format */ -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( std::shared_ptr in_stream, StreamOrDevice s = {}); -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( const std::string& file, StreamOrDevice s = {}); void save_safetensors( std::shared_ptr in_stream, - std::unordered_map); + std::unordered_map, + std::unordered_map metadata = {}); void save_safetensors( const std::string& file, - std::unordered_map); - -using MetaData = - std::variant>; + std::unordered_map, + std::unordered_map metadata = {}); /** Load array map and metadata from .gguf file format */ -std::pair< - std::unordered_map, - std::unordered_map> -load_gguf(const std::string& file, StreamOrDevice s = {}); + +GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); void save_gguf( std::string file, std::unordered_map array_map, - std::unordered_map meta_data = {}); + std::unordered_map meta_data = {}); } // namespace mlx::core diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index f4047d1a0..9e7953d6e 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -82,7 +82,7 @@ void set_mx_value_from_gguf( gguf_ctx* ctx, uint32_t type, gguf_value* val, - MetaData& value) { + GGUFMetaData& value) { switch (type) { case GGUF_VALUE_TYPE_UINT8: value = array(val->uint8, uint8); @@ -191,12 +191,12 @@ void set_mx_value_from_gguf( } } -std::unordered_map load_metadata(gguf_ctx* ctx) { - std::unordered_map metadata; +std::unordered_map load_metadata(gguf_ctx* ctx) { + std::unordered_map metadata; gguf_key key; while (gguf_get_key(ctx, &key)) { std::string key_name = std::string(key.name, key.namelen); - auto& val = metadata.insert({key_name, MetaData{}}).first->second; + auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second; set_mx_value_from_gguf(ctx, key.type, key.val, val); } return metadata; @@ -230,10 +230,7 @@ std::unordered_map load_arrays(gguf_ctx* ctx) { return array_map; } -std::pair< - std::unordered_map, - std::unordered_map> -load_gguf(const std::string& file, StreamOrDevice s) { +GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { gguf_ctx* ctx = gguf_open(file.c_str()); if (!ctx) { throw std::runtime_error("[load_gguf] gguf_init failed"); @@ -280,7 +277,7 @@ void append_kv_array( void save_gguf( std::string file, std::unordered_map array_map, - std::unordered_map metadata /* = {} */) { + std::unordered_map metadata /* = {} */) { // Add .gguf to file name if it is not there if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") { file += ".gguf"; diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index 636648bc7..b9fe1e3bf 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -114,7 +114,6 @@ void gguf_load_quantized( << "has incompatible last dim shape: " << shape[shape.size() - 1]; throw std::runtime_error(msg.str()); } - const uint64_t num_blocks = tensor.num_weights / weights_per_block; std::vector weights_shape = shape; weights_shape.back() /= (weights_per_byte * 4); diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 7e7868d49..1dd59f444 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) { } /** Load array from reader in safetensor format */ -std::unordered_map load_safetensors( +SafetensorsLoad load_safetensors( std::shared_ptr in_stream, StreamOrDevice s) { //////////////////////////////////////////////////////// @@ -121,9 +121,12 @@ std::unordered_map load_safetensors( size_t offset = jsonHeaderLength + 8; // Load the arrays using metadata std::unordered_map res; + std::unordered_map metadata_map; for (const auto& item : metadata.items()) { if (item.key() == "__metadata__") { - // ignore metadata for now + for (const auto& meta_item : item.value().items()) { + metadata_map.insert({meta_item.key(), meta_item.value()}); + } continue; } std::string dtype = item.value().at("dtype"); @@ -138,19 +141,18 @@ std::unordered_map load_safetensors( std::vector{}); res.insert({item.key(), loaded_array}); } - return res; + return {res, metadata_map}; } -std::unordered_map load_safetensors( - const std::string& file, - StreamOrDevice s) { +SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) { return load_safetensors(std::make_shared(file), s); } /** Save array to out stream in .npy format */ void save_safetensors( std::shared_ptr out_stream, - std::unordered_map a) { + std::unordered_map a, + std::unordered_map metadata /* = {} */) { //////////////////////////////////////////////////////// // Check file if (!out_stream->good() || !out_stream->is_open()) { @@ -161,9 +163,11 @@ void save_safetensors( //////////////////////////////////////////////////////// // Check array map json parent; - parent["__metadata__"] = json::object({ - {"format", "mlx"}, - }); + json _metadata; + for (auto& [key, value] : metadata) { + _metadata[key] = value; + } + parent["__metadata__"] = _metadata; size_t offset = 0; for (auto& [key, arr] : a) { arr.eval(); @@ -204,7 +208,8 @@ void save_safetensors( void save_safetensors( const std::string& file_, - std::unordered_map a) { + std::unordered_map a, + std::unordered_map metadata /* = {} */) { // Open and check file std::string file = file_; @@ -214,7 +219,7 @@ void save_safetensors( file += ".safetensors"; // Serialize array - save_safetensors(std::make_shared(file), a); + save_safetensors(std::make_shared(file), a, metadata); } } // namespace mlx::core diff --git a/mlx/mlx.h b/mlx/mlx.h index 7b33faba7..1963a4c50 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -6,6 +6,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/compile.h" #include "mlx/device.h" +#include "mlx/fast.h" #include "mlx/fft.h" #include "mlx/io.h" #include "mlx/linalg.h" diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 01ee6d388..97d4a3a2d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) { } // namespace -Stream to_stream(StreamOrDevice s) { - if (std::holds_alternative(s)) { - return default_stream(default_device()); - } else if (std::holds_alternative(s)) { - return default_stream(std::get(s)); - } else { - return std::get(s); - } -} - array arange( double start, double stop, @@ -632,6 +622,13 @@ std::vector split( std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) { + auto ax = axis < 0 ? axis + a.ndim() : axis; + if (ax < 0 || ax >= a.ndim()) { + std::ostringstream msg; + msg << "Invalid axis " << axis << " passed to split" + << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } auto q_and_r = std::ldiv(a.shape(axis), num_splits); if (q_and_r.rem) { std::ostringstream msg; @@ -3384,4 +3381,34 @@ std::vector depends( shapes, dtypes, std::make_shared(to_stream(s)), all_inputs); } +array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() == 0) { + return reshape(a, {1}, s); + } + return a; +} + +array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1}, s); + case 1: + return reshape(a, {1, static_cast(a.size())}, s); + default: + return a; + } +} + +array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1, 1}, s); + case 1: + return reshape(a, {1, static_cast(a.size()), 1}, s); + case 2: + return reshape(a, {a.shape(0), a.shape(1), 1}, s); + default: + return a; + } +} } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index a4b1dd1ef..b61224d65 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -3,18 +3,14 @@ #pragma once #include -#include #include "mlx/array.h" #include "mlx/device.h" #include "mlx/stream.h" +#include "mlx/utils.h" namespace mlx::core { -using StreamOrDevice = std::variant; - -Stream to_stream(StreamOrDevice s); - /** Creation operations */ /** @@ -1125,4 +1121,9 @@ std::vector depends( const std::vector& inputs, const std::vector& dependencies); +/** convert an array to an atleast ndim array */ +array atleast_1d(const array& a, StreamOrDevice s = {}); +array atleast_2d(const array& a, StreamOrDevice s = {}); +array atleast_3d(const array& a, StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index a7e1d205d..b2daaa2f8 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -71,6 +71,15 @@ std::pair, std::vector> Primitive::vmap( throw std::invalid_argument("Primitive's vmap not implemented."); }; +std::vector> Primitive::output_shapes( + const std::vector&) { + std::ostringstream msg; + msg << "[Primitive::output_shapes] "; + this->print(msg); + msg << " cannot infer output shapes."; + throw std::invalid_argument(msg.str()); +}; + std::vector Abs::vjp( const std::vector& primals, const std::vector& cotangents, @@ -383,6 +392,13 @@ std::pair, std::vector> ArgSort::vmap( return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; } +std::vector> ArgReduce::output_shapes( + const std::vector& inputs) { + auto out_shape = inputs[0].shape(); + out_shape[axis_] = 1; + return {out_shape}; +} + bool ArgSort::is_equivalent(const Primitive& other) const { const ArgSort& r_other = static_cast(other); return axis_ == r_other.axis_; @@ -628,7 +644,6 @@ std::vector Convolution::vjp( auto& wt = primals[1]; auto cotan = cotangents[0]; - int N = in.shape(0); int O = wt.shape(0); // Resolve Padded input shapes and strides @@ -2203,6 +2218,15 @@ bool Reduce::is_equivalent(const Primitive& other) const { return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_; } +std::vector> Reduce::output_shapes( + const std::vector& inputs) { + std::vector out_shape = inputs[0].shape(); + for (auto i : axes_) { + out_shape[i] = 1; + } + return {out_shape}; +} + std::vector Round::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 5bdee12cf..73e4394a5 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -36,6 +36,12 @@ return true; \ } +#define DEFINE_INPUT_OUTPUT_SHAPE() \ + std::vector> output_shapes( \ + const std::vector& inputs) override { \ + return {inputs[0].shape()}; \ + }; + namespace mlx::core { // Abstract base class @@ -102,6 +108,11 @@ class Primitive { return false; } + /** Get the output shapes of the primitive. This is not required to be + * implemented by derived classes, in which case it will throw. */ + virtual std::vector> output_shapes( + const std::vector& inputs); + virtual ~Primitive() = default; Primitive(const Primitive& other) = delete; Primitive(Primitive&& other) = delete; @@ -152,6 +163,7 @@ class Abs : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -168,6 +180,7 @@ class Add : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Add) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -226,6 +239,7 @@ class ArcCos : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -242,6 +256,7 @@ class ArcCosh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -258,6 +273,7 @@ class ArcSin : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -274,6 +290,7 @@ class ArcSinh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -290,6 +307,7 @@ class ArcTan : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -306,6 +324,7 @@ class ArcTanh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -321,6 +340,7 @@ class ArgPartition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgPartition) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -346,6 +366,8 @@ class ArgReduce : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgReduce) bool is_equivalent(const Primitive& other) const override; + std::vector> output_shapes( + const std::vector& inputs) override; private: ReduceType reduce_type_; @@ -364,6 +386,7 @@ class ArgSort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(ArgSort) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -383,6 +406,7 @@ class AsType : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(AsType) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -448,6 +472,7 @@ class Ceil : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Ceil) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -473,24 +498,27 @@ class Compiled : public Primitive { void eval_cpu(const std::vector& inputs, std::vector& outputs) override; - void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_VMAP() DEFINE_GRADS() - + std::vector> output_shapes( + const std::vector& inputs) override; void print(std::ostream& os) override; - bool is_equivalent(const Primitive& other) const override; + std::string lib_name() const { + return kernel_lib_; + } + private: const std::vector inputs_; const std::vector outputs_; const std::vector tape_; const std::unordered_set constant_ids_; - void eval(const std::vector& inputs, std::vector& out); + std::string kernel_lib_; }; class Concatenate : public UnaryPrimitive { @@ -558,6 +586,7 @@ class Copy : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -574,6 +603,7 @@ class Cos : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -590,6 +620,7 @@ class Cosh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -660,6 +691,7 @@ class Divide : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -678,6 +710,10 @@ class DivMod : public Primitive { DEFINE_GRADS() DEFINE_PRINT(DivMod) DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector> output_shapes( + const std::vector& inputs) override { + return std::vector{inputs[0].shape(), inputs[0].shape()}; + }; private: void eval(const std::vector& inputs, std::vector& outputs); @@ -694,6 +730,7 @@ class Remainder : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Remainder) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -709,8 +746,16 @@ class Equal : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Equal) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + void print(std::ostream& os) override { + if (equal_nan_) { + os << "NanEqual"; + } else { + os << "Equal"; + } + } private: void eval(const std::vector& inputs, array& out); @@ -728,6 +773,7 @@ class Erf : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -744,6 +790,7 @@ class ErfInv : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -760,6 +807,7 @@ class Exp : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -802,6 +850,7 @@ class Floor : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Floor) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -856,6 +905,7 @@ class Greater : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -872,6 +922,7 @@ class GreaterEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -888,6 +939,7 @@ class Less : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Less) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -904,6 +956,7 @@ class LessEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -945,8 +998,22 @@ class Log : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Log) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + void print(std::ostream& os) override { + switch (base_) { + case e: + os << "Log"; + break; + case two: + os << "Log2"; + break; + case ten: + os << "Log10"; + break; + } + } private: Base base_; @@ -963,6 +1030,7 @@ class Log1p : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Log1p) + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -979,6 +1047,7 @@ class LogicalNot : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -995,6 +1064,7 @@ class LogicalAnd : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalAnd) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1011,6 +1081,7 @@ class LogicalOr : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogicalOr) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1027,6 +1098,7 @@ class LogAddExp : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1060,6 +1132,7 @@ class Maximum : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1076,6 +1149,7 @@ class Minimum : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1092,6 +1166,7 @@ class Multiply : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1108,6 +1183,7 @@ class Negative : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1124,6 +1200,7 @@ class NotEqual : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1168,6 +1245,7 @@ class Partition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Partition) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -1188,6 +1266,7 @@ class Power : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Power) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1280,6 +1359,9 @@ class Reduce : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; + std::vector> output_shapes( + const std::vector& inputs) override; + void print(std::ostream& os) override { switch (reduce_type_) { case And: @@ -1322,6 +1404,7 @@ class Round : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Round) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1430,6 +1513,7 @@ class Sigmoid : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1446,6 +1530,7 @@ class Sign : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1462,6 +1547,7 @@ class Sin : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1478,6 +1564,7 @@ class Sinh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1522,6 +1609,7 @@ class Softmax : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Softmax) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1538,6 +1626,7 @@ class Sort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sort) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; private: @@ -1579,6 +1668,7 @@ class Square : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Square) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1594,9 +1684,17 @@ class Sqrt : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sqrt) + DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; + void print(std::ostream& os) override { + if (recip_) { + os << "Rsqrt"; + } else { + os << "Sqrt"; + } + } + private: void eval(const std::vector& inputs, array& out); bool recip_; @@ -1612,6 +1710,7 @@ class StopGradient : public UnaryPrimitive { DEFINE_VMAP() DEFINE_PRINT(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1628,6 +1727,7 @@ class Subtract : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1644,6 +1744,7 @@ class Tan : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); @@ -1660,6 +1761,7 @@ class Tanh : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() private: void eval(const std::vector& inputs, array& out); diff --git a/mlx/random.cpp b/mlx/random.cpp index 63e39cdcc..5e0682d32 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -153,14 +153,23 @@ array uniform( array normal( const std::vector& shape, Dtype dtype, + const float loc /* = 0.0 */, + const float scale /* = 1.0 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); auto low = array(std::nextafter(-1.0f, 0.0f), dtype); auto high = array(1.0f, dtype); auto samples = uniform(low, high, shape, dtype, key, stream); - return multiply( - array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); + samples = + multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); + if (scale != 1.0) { + samples = multiply(array(scale, dtype), samples, stream); + } + if (loc != 0.0) { + samples = add(array(loc, dtype), samples, stream); + } + return samples; } array randint( diff --git a/mlx/random.h b/mlx/random.h index ab75eb488..1397b32d7 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -95,13 +95,30 @@ inline array uniform( array normal( const std::vector& shape, Dtype dtype, + const float loc, + const float scale, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array normal( const std::vector& shape, + const float loc, + const float scale, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, float32, key, s); + return normal(shape, float32, loc, scale, key, s); +} +inline array normal( + const std::vector& shape, + const Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, dtype, 0.0, 1.0, key, s); +} +inline array normal( + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, 0.0, 1.0, key, s); } /** Generate integer samples uniformly at random */ diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index a1b3461ab..6c7959426 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -18,7 +18,9 @@ std::vector vmap_replace( // idea. std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id); + size_t fun_id, + bool shapeless = false, + std::vector constants = {}); // Erase cached compile functions void compile_erase(size_t fun_id); diff --git a/mlx/types/complex.h b/mlx/types/complex.h index 55cbe447a..f8a607766 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -35,6 +35,16 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) { return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); } +inline complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && ((real < 0) != (b.real() < 0))) + real += b.real(); + if (imag != 0 && ((imag < 0) != (b.imag() < 0))) + imag += b.imag(); + return {real, imag}; +} + inline bool operator<=(const complex64_t& a, const complex64_t& b) { return operator>=(b, a); } @@ -50,25 +60,30 @@ inline complex64_t operator-(const complex64_t& v) { // clang-format off #define complex_binop_helper(_op_, _operator_, itype) \ inline complex64_t _operator_(itype x, const complex64_t& y) { \ - return x _op_ static_cast>(y); \ + return static_cast(x) _op_ y; \ } \ inline complex64_t _operator_(const complex64_t& x, itype y) { \ - return static_cast>(x) _op_ y; \ + return x _op_ static_cast(y); \ } -#define complex_binop(_op_, _operator_) \ - inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ - return static_cast>(x) \ - _op_ static_cast>(y); \ - } \ - complex_binop_helper(_op_, _operator_, bool) \ - complex_binop_helper(_op_, _operator_, uint32_t) \ - complex_binop_helper(_op_, _operator_, uint64_t) \ - complex_binop_helper(_op_, _operator_, int32_t) \ - complex_binop_helper(_op_, _operator_, int64_t) \ - complex_binop_helper(_op_, _operator_, float16_t) \ - complex_binop_helper(_op_, _operator_, bfloat16_t) \ - complex_binop_helper(_op_, _operator_, const std::complex&) \ +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ + return static_cast>(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ complex_binop_helper(_op_, _operator_, float) // clang-format on diff --git a/mlx/utils.cpp b/mlx/utils.cpp index eece43717..c6365beb9 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -7,6 +7,16 @@ namespace mlx::core { +Stream to_stream(StreamOrDevice s) { + if (std::holds_alternative(s)) { + return default_stream(default_device()); + } else if (std::holds_alternative(s)) { + return default_stream(std::get(s)); + } else { + return std::get(s); + } +} + void PrintFormatter::print(std::ostream& os, bool val) { if (capitalize_bool) { os << (val ? "True" : "False"); diff --git a/mlx/utils.h b/mlx/utils.h index f28970369..ebcca3a1e 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "array.h" #include "device.h" #include "dtype.h" @@ -9,6 +11,30 @@ namespace mlx::core { +using StreamOrDevice = std::variant; +Stream to_stream(StreamOrDevice s); + +struct StreamContext { + public: + StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + auto _s = to_stream(s); + set_default_device(_s.device); + set_default_stream(_s); + } + + ~StreamContext() { + set_default_device(_stream.device); + set_default_stream(_stream); + } + + private: + Stream _stream; +}; + struct PrintFormatter { inline void print(std::ostream& os, bool val); inline void print(std::ostream& os, int16_t val); @@ -51,7 +77,7 @@ std::ostream& operator<<(std::ostream& os, array a); std::ostream& operator<<(std::ostream& os, const std::vector& v); std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { - return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j"; + return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; } inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { return os << static_cast(v); diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index 5afc6170e..6596ba741 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -60,7 +60,7 @@ def normal( """ def initializer(a: mx.array) -> mx.array: - return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean + return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) return initializer @@ -184,7 +184,7 @@ def glorot_normal( def initializer(a: mx.array, gain: float = 1.0) -> mx.array: fan_in, fan_out = _calculate_fan_in_fan_out(a) std = gain * math.sqrt(2.0 / (fan_in + fan_out)) - return mx.random.normal(shape=a.shape, dtype=dtype) * std + return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) return initializer @@ -285,7 +285,7 @@ def he_normal( raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out") std = gain / math.sqrt(fan) - return mx.random.normal(shape=a.shape, dtype=dtype) * std + return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) return initializer diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index f5092418b..207cb01b2 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -58,6 +58,7 @@ from mlx.nn.layers.normalization import ( LayerNorm, RMSNorm, ) +from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.transformer import ( diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index db07ce190..dfd435cfd 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import math +from functools import partial from typing import Any import mlx.core as mx @@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module def _make_activation_module(f): def decorator(klass): - klass.__doc__ = f.__doc__ - klass.__call__ = lambda self, x: f(x) + klass.__call__ = lambda _, x: f(x) return klass return decorator +@partial(mx.compile, shapeless=True) def sigmoid(x): r"""Applies the element-wise function: @@ -25,6 +26,7 @@ def sigmoid(x): return mx.sigmoid(x) +@partial(mx.compile, shapeless=True) def relu(x): r"""Applies the Rectified Linear Unit. @@ -33,6 +35,7 @@ def relu(x): return mx.maximum(x, 0) +@partial(mx.compile, shapeless=True) def leaky_relu(x, negative_slope=0.01): r"""Applies the Leaky Rectified Linear Unit. @@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01): return mx.maximum(negative_slope * x, x) +@partial(mx.compile, shapeless=True) def log_softmax(x, axis=-1): r"""Applies the Log Softmax function. @@ -49,6 +53,7 @@ def log_softmax(x, axis=-1): return x - mx.logsumexp(x, axis=axis, keepdims=True) +@partial(mx.compile, shapeless=True) def elu(x, alpha=1.0): r"""Applies the Exponential Linear Unit. @@ -57,6 +62,7 @@ def elu(x, alpha=1.0): return mx.where(x > 0, x, alpha * (mx.exp(x) - 1)) +@partial(mx.compile, shapeless=True) def relu6(x): r"""Applies the Rectified Linear Unit 6. @@ -65,6 +71,7 @@ def relu6(x): return mx.minimum(mx.maximum(x, 0), 6.0) +@partial(mx.compile, shapeless=True) def softmax(x, axis=-1): r"""Applies the Softmax function. @@ -73,6 +80,7 @@ def softmax(x, axis=-1): return mx.softmax(x, axis=axis) +@partial(mx.compile, shapeless=True) def softplus(x): r"""Applies the Softplus function. @@ -81,6 +89,7 @@ def softplus(x): return mx.logaddexp(x, 0) +@partial(mx.compile, shapeless=True) def softsign(x): r"""Applies the Softsign function. @@ -89,6 +98,7 @@ def softsign(x): return mx.divide(x, 1 + mx.abs(x)) +@partial(mx.compile, shapeless=True) def softshrink(x, lambd: float = 0.5): r"""Applies the Softshrink activation function. @@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5): return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0) +@partial(mx.compile, shapeless=True) def celu(x, alpha=1.0): r"""Applies the Continuously Differentiable Exponential Linear Unit. @@ -111,6 +122,7 @@ def celu(x, alpha=1.0): return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1) +@partial(mx.compile, shapeless=True) def silu(x): r"""Applies the Sigmoid Linear Unit. Also known as Swish. @@ -120,6 +132,7 @@ def silu(x): return x * mx.sigmoid(x) +@partial(mx.compile, shapeless=True) def log_sigmoid(x): r"""Applies the Log Sigmoid function. @@ -128,6 +141,7 @@ def log_sigmoid(x): return -softplus(-x) +@partial(mx.compile, shapeless=True) def gelu(x): r"""Applies the Gaussian Error Linear Units function. @@ -142,6 +156,7 @@ def gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 +@partial(mx.compile, shapeless=True) def gelu_approx(x): r"""An approximation to Gaussian Error Linear Unit. @@ -159,6 +174,7 @@ def gelu_approx(x): return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square())) +@partial(mx.compile, shapeless=True) def gelu_fast_approx(x): r"""A fast approximation to Gaussian Error Linear Unit. @@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array: return a * mx.sigmoid(b) -class GLU(Module): - r"""Applies the gated linear unit function. - - This function splits the ``axis`` dimension of the input into two halves - (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. - - .. math:: - textrm{GLU}(x) = a * \sigma(b) - - Args: - axis (int): The dimension to split along. Default: ``-1`` - """ - - def __init__(self, axis: int = -1): - super().__init__() - self.axis = axis - - def __call__(self, x) -> Any: - return glu(x=x, axis=self.axis) - - +@partial(mx.compile, shapeless=True) def step(x: mx.array, threshold: float = 0.0): r"""Applies the Step Activation Function. @@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0): return mx.where(x > threshold, 1, 0) +@partial(mx.compile, shapeless=True) def selu(x): r"""Applies the Scaled Exponential Linear Unit. @@ -248,6 +245,7 @@ def selu(x): return elu(x, 1.67326) * 1.0507 +@partial(mx.compile, shapeless=True) def prelu(x: mx.array, alpha: mx.array) -> mx.array: r"""Applies the element-wise parametric ReLU. @@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array: return mx.maximum(0, x) + alpha * mx.minimum(0, x) +@partial(mx.compile, shapeless=True) def mish(x: mx.array) -> mx.array: r"""Applies the Mish function, element-wise. Mish: A Self Regularized Non-Monotonic Neural Activation Function. @@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array: return x * mx.tanh(softplus(x)) +@partial(mx.compile, shapeless=True) def hardswish(x): r"""Applies the hardswish function, element-wise. @@ -282,6 +282,35 @@ def hardswish(x): return x * mx.minimum(max_x_3, 6) / 6 +def tanh(x): + """Applies the hyperbolic tangent function. + + Simply ``mx.tanh(x)``. + """ + return mx.tanh(x) + + +class GLU(Module): + r"""Applies the gated linear unit function. + + This function splits the ``axis`` dimension of the input into two halves + (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. + + .. math:: + textrm{GLU}(x) = a * \sigma(b) + + Args: + axis (int): The dimension to split along. Default: ``-1`` + """ + + def __init__(self, axis: int = -1): + super().__init__() + self.axis = axis + + def __call__(self, x) -> Any: + return glu(x=x, axis=self.axis) + + @_make_activation_module(mx.sigmoid) class Sigmoid(Module): r"""Applies the sigmoid function, element-wise. @@ -500,14 +529,6 @@ class GELU(Module): return self._act(x) -def tanh(x): - """Applies the hyperbolic tangent function. - - Simply ``mx.tanh(x)``. - """ - return mx.tanh(x) - - @_make_activation_module(tanh) class Tanh(Module): r"""Applies the hyperbolic tangent function. diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 3da1993ec..de7097673 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -66,6 +66,19 @@ class Module(dict): """Boolean indicating if the model is in training mode.""" return self._training + @property + def state(self): + """The module's state dictionary + + The module's state dictionary contains any attribute set on the + module including parameters in :meth:`Module.parameters` + + Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is + a reference to the module's state. Updates to it will be reflected in + the original module. + """ + return self + def _extra_repr(self): return "" @@ -312,7 +325,7 @@ class Module(dict): elif isinstance(current_value, (dict, list)): apply(current_value, new_value) elif isinstance(parameters, list): - for i in range(len(dst)): + for i in range(len(parameters)): current_value = dst[i] new_value = parameters[i] if isinstance(current_value, mx.array): diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py index c62b1206f..18482eddc 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -21,7 +21,7 @@ class Embedding(Module): def __init__(self, num_embeddings: int, dims: int): super().__init__() scale = math.sqrt(1 / dims) - self.weight = mx.random.normal((num_embeddings, dims)) * scale + self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) def _extra_repr(self): return f"{self.weight.shape[0]}, {self.weight.shape[1]}" diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py new file mode 100644 index 000000000..ffa05f5d2 --- /dev/null +++ b/python/mlx/nn/layers/pooling.py @@ -0,0 +1,308 @@ +# Copyright © 2023-2024 Apple Inc. + +import operator +from itertools import accumulate +from typing import Optional, Tuple, Union + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +def _value_or_list(x, n, msg): + if isinstance(x, (list, tuple)): + if len(x) != n: + raise ValueError(msg) + return list(x) + + if not isinstance(x, int): + raise ValueError(msg) + + return [x] * n + + +def _sliding_windows(x, window_shape, window_strides): + if x.ndim < 3: + raise ValueError( + f"To extract sliding windows at least 1 spatial dimension " + f"(3 total) is needed but the input only has {x.ndim} dimensions." + ) + + spatial_dims = x.shape[1:-1] + if not (len(spatial_dims) == len(window_shape) == len(window_strides)): + raise ValueError( + f"To extract sliding windows the window shapes and strides must have " + f"the same number of spatial dimensions as the signal but the signal " + f"has {len(spatial_dims)} dims and the window shape has {len(window_shape)} " + f"and strides have {len(window_strides)}." + ) + + shape = x.shape + strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:] + + # Compute the output shape + final_shape = [shape[0]] + final_shape += [ + (size - window) // stride + 1 + for size, window, stride in zip(spatial_dims, window_shape, window_strides) + ] + final_shape += window_shape + final_shape += [shape[-1]] + + # Compute the output strides + final_strides = strides[:1] + final_strides += [ + og_stride * stride for og_stride, stride in zip(strides[1:-1], window_strides) + ] + final_strides += strides[1:-1] + final_strides += strides[-1:] # should always be [1] + + return mx.as_strided(x, final_shape, final_strides) + + +class _Pool(Module): + def __init__(self, pooling_function, kernel_size, stride, padding, padding_value): + super().__init__() + + self._pooling_function = pooling_function + self._kernel_size = kernel_size + self._stride = stride + self._padding = padding + self._padding_value = padding_value + self._axes = tuple(range(-len(self._kernel_size) - 1, -1, 1)) + + def _extra_repr(self): + ks = tuple(self._kernel_size) + st = tuple(self._stride) + pd = tuple(p[0] for p in self._padding) + + return f"kernel_size={ks}, stride={st}, padding={pd}" + + def __call__(self, x): + if any(p[0] > 0 for p in self._padding): + x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value) + x = _sliding_windows(x, self._kernel_size, self._stride) + return self._pooling_function(x, self._axes) + + +class _Pool1d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int]], + stride: Optional[Union[int, Tuple[int]]] = None, + padding: Union[int, Tuple[int]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 1 integer" + kernel_size = _value_or_list( + kernel_size, 1, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 1, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 1, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class _Pool2d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 2 integers" + kernel_size = _value_or_list( + kernel_size, 2, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 2, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 2, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class MaxPool1d(_Pool1d): + r"""Applies 1-dimensional max pooling. + + Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is + :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given + by: + + .. math:: + \text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1} + \text{input}(N_i, \text{stride} \times t + m, C_j), + + where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - + \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + + Args: + kernel_size (int or tuple(int)): The size of the pooling window kernel. + stride (int or tuple(int), optional): The stride of the pooling window. + Default: ``kernel_size``. + padding (int or tuple(int), optional): How much negative infinity + padding to apply to the input. The padding amount is applied to + both sides of the spatial axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(4, 16, 5)) + >>> pool = nn.MaxPool1d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool1d(_Pool1d): + r"""Applies 1-dimensional average pooling. + + Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is + :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given + by: + + .. math:: + \text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1} + \text{input}(N_i, \text{stride} \times t + m, C_j), + + where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - + \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + + Args: + kernel_size (int or tuple(int)): The size of the pooling window kernel. + stride (int or tuple(int), optional): The stride of the pooling window. + Default: ``kernel_size``. + padding (int or tuple(int), optional): How much zero padding to apply to + the input. The padding amount is applied to both sides of the spatial + axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(4, 16, 5)) + >>> pool = nn.AvgPool1d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) + + +class MaxPool2d(_Pool2d): + r"""Applies 2-dimensional max pooling. + + Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is + :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out}, + W_{out}, C)`, given by: + + .. math:: + \begin{aligned} + \text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n, C_j), + \end{aligned} + + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for both the + height and width axis; + - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is + used for the height axis, the second ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int)): The size of the pooling window. + stride (int or tuple(int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int), optional): How much negative infinity + padding to apply to the input. The padding is applied on both sides + of the height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 32, 32, 4)) + >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool2d(_Pool2d): + r"""Applies 2-dimensional average pooling. + + Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is + :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out}, + W_{out}, C)`, given by: + + .. math:: + \begin{aligned} + \text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n, C_j), + \end{aligned} + + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for both the + height and width axis; + - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is + used for the height axis, the second ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int)): The size of the pooling window. + stride (int or tuple(int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int), optional): How much zero + padding to apply to the input. The padding is applied on both sides + of the height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 32, 32, 4)) + >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) + >>> pool(x) + """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index a8024f0a4..f0bb92863 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import math from typing import Optional @@ -20,20 +20,13 @@ class RoPE(Module): Args: dims (int): The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged. - traditional (bool, optional): If set to True choose the traditional + traditional (bool, optional): If set to ``True`` choose the traditional implementation which is slightly less efficient. Default: ``False``. base (float, optional): The base used to compute angular frequency for each dimension in the positional encodings. Default: ``10000``. scale (float, optional): The scale used to scale the positions. Default: ``1.0``. - - Attributes: - _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values. - _cos_sin_theta_value (tuple): Cached cosine and sine values. """ - _cos_sin_theta_key = None - _cos_sin_theta_value = None - def __init__( self, dims: int, @@ -50,69 +43,18 @@ class RoPE(Module): def _extra_repr(self): return f"{self.dims}, traditional={self.traditional}" - def _compute_rope(self, costheta, sintheta, x): - x1 = x[..., : self.dims // 2] - x2 = x[..., self.dims // 2 : self.dims] - rx1 = x1 * costheta - x2 * sintheta - rx2 = x1 * sintheta + x2 * costheta - - if self.dims < x.shape[-1]: - rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1) - else: - rx = mx.concatenate([rx1, rx2], axis=-1) - - return rx - - def _compute_traditional_rope(self, costheta, sintheta, x): - x1 = x[..., ::2] - x2 = x[..., 1::2] - rx1 = x1 * costheta - x2 * sintheta - rx2 = x1 * sintheta + x2 * costheta - - if self.dims < x.shape[-1]: - raise NotImplementedError( - "RoPE doesn't implement partial traditional application" - ) - - rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) - - return rx - def __call__(self, x, offset: int = 0): shape = x.shape x = mx.reshape(x, (-1, shape[-2], shape[-1])) - N = x.shape[1] + offset - costheta, sintheta = RoPE.create_cos_sin_theta( - N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype + x = mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + offset=offset, ) - - rope = ( - self._compute_traditional_rope if self.traditional else self._compute_rope - ) - rx = rope(costheta, sintheta, x) - - return mx.reshape(rx, shape) - - @classmethod - def create_cos_sin_theta( - cls, - N: int, - D: int, - offset: int = 0, - base: float = 10000, - scale: float = 1.0, - dtype=mx.float32, - ): - if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: - half_D = D // 2 - positions = mx.arange(offset, N, dtype=dtype) * scale - freqs = mx.exp( - -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) - ) - theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) - cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) - cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta)) - return cls._cos_sin_theta_value + return mx.reshape(x, shape) class SinusoidalPositionalEncoding(Module): diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index a466c10ed..ee33fde3e 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -117,6 +117,7 @@ def cross_entropy( def binary_cross_entropy( inputs: mx.array, targets: mx.array, + weights: mx.array = None, with_logits: bool = True, reduction: Reduction = "mean", ) -> mx.array: @@ -128,6 +129,7 @@ def binary_cross_entropy( ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities. targets (array): The binary target values in {0, 1}. with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``. + weights (array, optional): Optional weights for each target. Default: ``None``. reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``. @@ -159,6 +161,15 @@ def binary_cross_entropy( else: loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs)) + # Apply weights if provided + if weights is not None: + if weights.shape != loss.shape: + raise ValueError( + f"Weights with shape {weights.shape} is not the same as " + f"output loss with shape {loss.shape}." + ) + loss *= weights + return _reduce(loss, reduction) diff --git a/python/mlx/optimizers/__init__.py b/python/mlx/optimizers/__init__.py new file mode 100644 index 000000000..6e8e0ccd4 --- /dev/null +++ b/python/mlx/optimizers/__init__.py @@ -0,0 +1,4 @@ +# Copyright © 2023-2024 Apple Inc. + +from mlx.optimizers.optimizers import * +from mlx.optimizers.schedulers import * diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers/optimizers.py similarity index 69% rename from python/mlx/optimizers.py rename to python/mlx/optimizers/optimizers.py index b659ec5cf..16928625f 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -1,45 +1,21 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import math -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import mlx.core as mx from mlx.utils import tree_map -class OptimizerState(dict): - """The optimizer state implements a recursively defined - :class:`collections.defaultdict`, namely a missing key in an optimizer - state is an :class:`OptimizerState`. - - .. note:: - :meth:`OptimizerState.get` in contrast to a normal dictionary also sets - the key to the ``default`` value if the ``key`` was not present in the - dictionary. - """ - - def __getitem__(self, key): - if key not in self: - self[key] = OptimizerState() - return super().__getitem__(key) - - def get(self, key, default): - """If ``key`` doesn't exist set its value to ``default`` and then return it.""" - if key not in self: - self[key] = default - return super().__getitem__(key) - - class Optimizer: """The base class for all optimizers. It allows us to implement an optimizer on a per-parameter basis and apply it to a parameter tree. - - Attributes: - state (OptimizerState): It holds the optimizer's state dictionary. """ - def __init__(self): - self.state = OptimizerState() + def __init__(self, schedulers=None): + self._initialized = False + self._state = {"step": mx.array(0, mx.uint64)} + self._schedulers = {k: v for k, v in (schedulers or {}).items()} def update(self, model: "mlx.nn.Module", gradients: dict): """Apply the gradients to the parameters of the model and update the @@ -52,7 +28,40 @@ class Optimizer: """ model.update(self.apply_gradients(gradients, model)) - def apply_gradients(self, gradients: dict, model: dict): + def init(self, parameters: dict): + """Initialize the optimizer's state + + This function can be used to initialize optimizers which have state + (like momentum in :class:`SGD`). Using this method is optional as the + optimizer will initialize itself if the state is not yet set. However, + there are some cases where explicit initialization is useful in order + to have access to the :attr:`Optimizer.state` before the first call to + :meth:`Optimizer.update`. + + Args: + model (dict): A Python tree of parameters. + + Example: + >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9) + >>> model = nn.Linear(2, 2) + >>> optimizer.init(model.trainable_parameters()) + >>> optimizer.state.keys() + dict_keys(['step', 'learning_rate', 'weight', 'bias']) + """ + self._state.update(tree_map(lambda x: {}, parameters)) + tree_map(self.init_single, parameters, self._state) + self._initialized = True + + def init_single(self, parameter: mx.array, state: dict): + """To be extended by the children classes to implement each optimizer's + state initialization. + + Args: + parameter (mx.array): A single parameter that will be optimized. + """ + raise NotImplementedError() + + def apply_gradients(self, gradients: dict, parameters: dict): """Apply the gradients to the parameters and return the updated parameters. Can be used to update a model via @@ -61,19 +70,67 @@ class Optimizer: Args: gradients (dict): A Python tree of gradients. - model (dict): A Python tree of parameters. It can be a superset of - the gradients. In that case the returned python tree - will be of the same structure as the gradients. + parameters (dict): A Python tree of parameters. It can be a + superset of the gradients. In that case the returned python + tree will be of the same structure as the gradients. """ - return tree_map(self.apply_single, gradients, model, self.state) + if not self._initialized: + self.init(gradients) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): - """To be extended by the children classes to implement each optimizer's - update.""" + # Update any scheduled variables + for param, scheduler in self._schedulers.items(): + self.state[param] = scheduler(self.step) + + # Increment the step + self.state["step"] = self.step + 1 + + # Apply the update + return tree_map(self.apply_single, gradients, parameters, self.state) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): + """To be extended by derived classes to implement the optimizer's update. + + Args: + gradient (mx.array): The ``parameter`` gradient. + parameter (mx.array): The ``parameter`` to update. + state (dict): The optimizer's state. + """ raise NotImplementedError() + @property + def state(self): + """The optimizer's state dictionary.""" + return self._state + + @state.setter + def state(self, state: dict): + self._state = state + + @property + def step(self): + return self.state["step"] + + @property + def learning_rate(self): + return self.state["learning_rate"] + + @learning_rate.setter + def learning_rate(self, learning_rate: Union[float, mx.array]): + self.state["learning_rate"] = mx.array(learning_rate) + + def _maybe_schedule( + self, name: str, param: Union[float, Callable[[mx.array], mx.array]] + ): + """ + To be used by derived classes to optionally put a parameter on a schedule. + """ + if isinstance(param, Callable): + self._schedulers[name] = param + param = param(self.step) + else: + param = mx.array(param) + self.state[name] = param + class SGD(Optimizer): r"""The stochastic gradient descent optimizer. @@ -86,7 +143,7 @@ class SGD(Optimizer): w_{t+1} &= w_t - \lambda v_{t+1} Args: - learning_rate (float): The learning rate :math:`\lambda`. + learning_rate (float or callable): The learning rate :math:`\lambda`. momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0`` weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0`` dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0`` @@ -95,7 +152,7 @@ class SGD(Optimizer): def __init__( self, - learning_rate: float, + learning_rate: Union[float, Callable[[mx.array], mx.array]], momentum: float = 0.0, weight_decay: float = 0.0, dampening: float = 0.0, @@ -107,15 +164,17 @@ class SGD(Optimizer): ) super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.momentum = momentum self.weight_decay = weight_decay self.dampening = dampening self.nesterov = nesterov - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the SGD parameter update and stores :math:`v` in the optimizer state.""" @@ -123,24 +182,21 @@ class SGD(Optimizer): gradient += self.weight_decay * parameter if self.momentum <= 0: - return parameter - self.learning_rate * gradient + return parameter - self.learning_rate.astype(gradient.dtype) * gradient + v = self.momentum * state.get("v") if self.dampening > 0: - v = ( - state.get("v", (self.dampening / self.momentum) * gradient) - * self.momentum - ) v += (1 - self.dampening) * gradient else: - v = state.get("v", mx.zeros_like(gradient)) * self.momentum v += gradient if self.nesterov: update = gradient + self.momentum * v else: update = v + state["v"] = v - return parameter - self.learning_rate * update + return parameter - self.learning_rate.astype(gradient.dtype) * update class RMSprop(Optimizer): @@ -164,7 +220,7 @@ class RMSprop(Optimizer): def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.alpha = alpha self.eps = eps @@ -177,15 +233,17 @@ class RMSprop(Optimizer): f"RMSprop epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) alpha = self.alpha eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) + v = state["v"] v = alpha * v + (1 - alpha) * mx.square(gradient) state["v"] = v @@ -214,7 +272,7 @@ class Adagrad(Optimizer): def __init__(self, learning_rate: float, eps: float = 1e-8): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.eps = eps if self.eps < 0.0: @@ -222,16 +280,17 @@ class Adagrad(Optimizer): f"Adagrad epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adagrad parameter update and stores :math:`v` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) - v = v + mx.square(gradient) + v = state["v"] + mx.square(gradient) state["v"] = v return parameter - lr * gradient / (mx.sqrt(v) + eps) @@ -262,7 +321,7 @@ class AdaDelta(Optimizer): def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.rho = rho self.eps = eps if self.rho < 0.0: @@ -274,17 +333,20 @@ class AdaDelta(Optimizer): f"AdaDelta epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + state["u"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the AdaDelta parameter update and stores :math:`v` and :math:`u` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) rho = self.rho eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) - u = state.get("u", mx.zeros_like(gradient)) + v = state["v"] + u = state["u"] v = rho * v + (1 - rho) * mx.square(gradient) d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient @@ -325,21 +387,24 @@ class Adam(Optimizer): ): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.betas = betas self.eps = eps - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adam parameter update and stores :math:`v` and :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas eps = self.eps - m = state.get("m", gradient) - v = state.get("v", mx.square(gradient)) + m = state["m"] + v = state["v"] m = b1 * m + (1 - b1) * gradient v = b2 * v + (1 - b2) * mx.square(gradient) state["m"] = m @@ -385,15 +450,14 @@ class AdamW(Adam): super().__init__(learning_rate=learning_rate, betas=betas, eps=eps) self.weight_decay = weight_decay - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the AdamW parameter update by modifying the parameters passed into Adam. """ + lr = self.learning_rate.astype(gradient.dtype) return super().apply_single( - gradient, parameter * (1 - self.learning_rate * self.weight_decay), state + gradient, parameter * (1 - lr * self.weight_decay), state ) @@ -430,17 +494,20 @@ class Adamax(Adam): f"Epsilon value should be >=0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adamax parameter update and stores :math:`v` and :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas eps = self.eps - m = state.get("m", mx.zeros_like(gradient)) - v = state.get("v", mx.zeros_like(gradient)) + m = state["m"] + v = state["v"] m = b1 * m + (1 - b1) * gradient v = mx.maximum(b2 * v, mx.abs(gradient)) @@ -485,20 +552,22 @@ class Lion(Optimizer): ): super().__init__() - self.learning_rate = learning_rate + self._maybe_schedule("learning_rate", learning_rate) self.betas = betas self.weight_decay = weight_decay - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Lion parameter update and stores :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas weight_decay = self.weight_decay - m = state.get("m", gradient) + m = state["m"] c = b1 * m + (1 - b1) * gradient state["m"] = b2 * m + (1 - b2) * gradient if weight_decay > 0: @@ -552,7 +621,8 @@ class Adafactor(Optimizer): warmup_init: bool = False, ): super().__init__() - self.learning_rate = learning_rate + if learning_rate is not None: + self._maybe_schedule("learning_rate", learning_rate) self.eps = eps self.clip_threshold = clip_threshold self.decay_rate = decay_rate @@ -562,15 +632,30 @@ class Adafactor(Optimizer): self.relative_step = relative_step self.warmup_init = warmup_init + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + if parameter.ndim >= 2: + shape = parameter.shape + dtype = parameter.dtype + state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype) + state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype) + else: + state["exp_avg_sq"] = mx.zeros_like(parameter) + + if self.beta_1 is not None: + state["exp_avg"] = mx.zeros_like(parameter) + def _compute_rms(self, inputs): return mx.sqrt(mx.mean(mx.square(inputs))) def _compute_learning_rate(self, step, parameter_rms): - relative_step_size = self.learning_rate if self.relative_step: min_step = 1e-6 * step if self.warmup_init else 1e-2 - relative_step_size = min(min_step, 1 / math.sqrt(step)) + relative_step_size = mx.minimum(min_step, mx.rsqrt(step)) + else: + relative_step_size = self.learning_rate + relative_step_size = relative_step_size.astype(parameter_rms.dtype) parameter_scale = 1.0 if self.scale_parameter: parameter_scale = mx.maximum(self.eps[1], parameter_rms) @@ -585,31 +670,21 @@ class Adafactor(Optimizer): mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adafactor parameter and state update.""" - gradient_shape = gradient.shape - factored = len(gradient_shape) >= 2 - step = state.get("step", 0) + 1 - state["step"] = step + factored = gradient.ndim >= 2 + + step = self.step use_first_moment = self.beta_1 is not None parameter_rms = self._compute_rms(parameter) learning_rate = self._compute_learning_rate(step, parameter_rms) - beta_2 = 1.0 - math.pow(step, self.decay_rate) + beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype) update = mx.square(gradient) + self.eps[0] if factored: - exp_avg_sq_row = state.get( - "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype) - ) - exp_avg_sq_col = state.get( - "exp_avg_sq_col", - mx.zeros( - gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype - ), - ) + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( (1 - beta_2) * mx.mean(update, axis=-1) ) @@ -621,7 +696,7 @@ class Adafactor(Optimizer): update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) update = update * gradient else: - exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient)) + exp_avg_sq = state["exp_avg_sq"] exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) state["exp_avg_sq"] = exp_avg_sq update = mx.rsqrt(exp_avg_sq) * gradient @@ -632,7 +707,7 @@ class Adafactor(Optimizer): update = learning_rate * update if use_first_moment: - exp_avg = state.get("exp_avg", mx.zeros_like(gradient)) + exp_avg = state["exp_avg"] exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) state["exp_avg"] = exp_avg update = exp_avg diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py new file mode 100644 index 000000000..da058c03a --- /dev/null +++ b/python/mlx/optimizers/schedulers.py @@ -0,0 +1,86 @@ +# Copyright © 2023-2024 Apple Inc. + +import math + +import mlx.core as mx + + +def exponential_decay(init: float, decay_rate: float): + r"""Make an exponential decay scheduler. + + Args: + init (float): Initial value. + decay_rate (float): Multiplicative factor to decay by. + + Example: + >>> lr_schedule = optim.exponential_decay(1e-1, 0.9) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(5): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.06561, dtype=float32) + """ + + def schedule(step): + return init * decay_rate**step + + return schedule + + +def step_decay(init: float, decay_rate: float, step_size: int): + r"""Make a step decay scheduler. + + Args: + init (float): Initial value. + decay_rate (float): Multiplicative factor to decay by. + step_size (int): Decay every ``step_size`` steps. + + Example: + + >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(21): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.081, dtype=float32) + """ + + def schedule(step): + return init * (decay_rate ** (step // step_size)) + + return schedule + + +def cosine_decay(init: float, decay_steps: int): + r"""Make a cosine decay scheduler. + + Args: + init (float): Initial value. + decay_steps (int): Number of steps to decay over. The decayed + value is constant for steps beyond ``decay_steps``. + + Example: + + >>> lr_schedule = optim.cosine_decay(1e-1, 1000) + >>> optimizer = optim.SGD(learning_rate=lr_schedule) + >>> optimizer.learning_rate + array(0.1, dtype=float32) + >>> + >>> for _ in range(5): optimizer.update({}, {}) + ... + >>> optimizer.learning_rate + array(0.0999961, dtype=float32) + """ + + def scheduler(step): + s = mx.minimum(step, decay_steps) + decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s)) + return init * decay + + return scheduler diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 137a8aae4..802b03831 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,5 +1,4 @@ # Copyright © 2023 Apple Inc. - from collections import defaultdict diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 1ba037fdc..4df503a4a 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -3,6 +3,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp @@ -13,6 +14,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/array.cpp b/python/src/array.cpp index 4d8c22748..0b34e2f71 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -23,15 +23,15 @@ enum PyScalarT { pycomplex = 3, }; -template +template py::list to_list(array& a, size_t index, int dim) { py::list pl; auto stride = a.strides()[dim]; for (int i = 0; i < a.shape(dim); ++i) { if (dim == a.ndim() - 1) { - pl.append((a.data()[index])); + pl.append(static_cast(a.data()[index])); } else { - pl.append(to_list(a, index, dim + 1)); + pl.append(to_list(a, index, dim + 1)); } index += stride; } @@ -102,11 +102,11 @@ py::object tolist(array& a) { case int64: return to_list(a, 0, 0); case float16: - return to_list(a, 0, 0); + return to_list(a, 0, 0); case float32: return to_list(a, 0, 0); case bfloat16: - return to_list(a, 0, 0); + return to_list(a, 0, 0); case complex64: return to_list>(a, 0, 0); } @@ -990,6 +990,12 @@ void init_array(py::module_& m) { return power(a, to_array(v, a.dtype())); }, "other"_a) + .def( + "__rpow__", + [](const array& a, const ScalarOrArray v) { + return power(to_array(v, a.dtype()), a); + }, + "other"_a) .def( "__ipow__", [](array& a, const ScalarOrArray v) { diff --git a/python/src/device.cpp b/python/src/device.cpp index 8c36f0f85..c88144520 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -12,7 +12,8 @@ using namespace py::literals; using namespace mlx::core; void init_device(py::module_& m) { - auto device_class = py::class_(m, "Device"); + auto device_class = py::class_( + m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); py::enum_(m, "DeviceType") .value("cpu", Device::DeviceType::cpu) .value("gpu", Device::DeviceType::gpu) @@ -39,6 +40,13 @@ void init_device(py::module_& m) { py::implicitly_convertible(); - m.def("default_device", &default_device); - m.def("set_default_device", &set_default_device, "device"_a); + m.def( + "default_device", + &default_device, + R"pbdoc(Get the default device.)pbdoc"); + m.def( + "set_default_device", + &set_default_device, + "device"_a, + R"pbdoc(Set the default device.)pbdoc"); } diff --git a/python/src/fast.cpp b/python/src/fast.cpp new file mode 100644 index 000000000..115ea37ec --- /dev/null +++ b/python/src/fast.cpp @@ -0,0 +1,59 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +#include "mlx/fast.h" +#include "mlx/ops.h" +#include "python/src/utils.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +void init_extensions(py::module_& parent_module) { + py::options options; + options.disable_function_signatures(); + + auto m = + parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); + + m.def( + "rope", + [](const array& a, + int dims, + bool traditional, + float base, + float scale, + int offset, + const StreamOrDevice& s /* = {} */) { + return fast::rope(a, dims, traditional, base, scale, offset, s); + }, + "a"_a, + "dims"_a, + py::kw_only(), + "traditional"_a, + "base"_a, + "scale"_a, + "offset"_a, + "stream"_a = none, + R"pbdoc( + rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array + + Apply rotary positional encoding to the input. + + Args: + a (array): Input array. + dims (int): The feature dimensions to be rotated. If the input feature + is larger than dims then the rest is left unchanged. + traditional (bool): If set to ``True`` choose the traditional + implementation which rotates consecutive dimensions. + base (float): The base used to compute angular frequency for + each dimension in the positional encodings. + scale (float): The scale used to scale the positions. + offset (int): The position offset to start at. + + Returns: + array: The output array. + )pbdoc"); +} diff --git a/python/src/load.cpp b/python/src/load.cpp index 18e89c7fb..9b6a6861e 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -160,31 +160,29 @@ class PyFileReader : public io::Reader { py::object tell_func_; }; -std::unordered_map mlx_load_safetensor_helper( - py::object file, - StreamOrDevice s) { +std::pair< + std::unordered_map, + std::unordered_map> +mlx_load_safetensor_helper(py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .safetensors file path string return load_safetensors(py::cast(file), s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately - auto arr = load_safetensors(std::make_shared(file), s); + auto res = load_safetensors(std::make_shared(file), s); { py::gil_scoped_release gil; - for (auto& [key, arr] : arr) { + for (auto& [key, arr] : std::get<0>(res)) { arr.eval(); } } - return arr; + return res; } throw std::invalid_argument( "[load_safetensors] Input must be a file-like object, or string"); } -std::pair< - std::unordered_map, - std::unordered_map> -mlx_load_gguf_helper(py::object file, StreamOrDevice s) { +GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .gguf file path string return load_gguf(py::cast(file), s); } @@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper( format.emplace(fname.substr(ext + 1)); } - if (return_metadata && format.value() != "gguf") { + if (return_metadata && (format.value() == "npy" || format.value() == "npz")) { throw std::invalid_argument( "[load] metadata not supported for format " + format.value()); } if (format.value() == "safetensors") { - return mlx_load_safetensor_helper(file, s); + auto [dict, metadata] = mlx_load_safetensor_helper(file, s); + if (return_metadata) { + return std::make_pair(dict, metadata); + } + return dict; } else if (format.value() == "npz") { return mlx_load_npz_helper(file, s); } else if (format.value() == "npy") { @@ -444,18 +446,33 @@ void mlx_savez_helper( return; } -void mlx_save_safetensor_helper(py::object file, py::dict d) { +void mlx_save_safetensor_helper( + py::object file, + py::dict d, + std::optional m) { + std::unordered_map metadata_map; + if (m) { + try { + metadata_map = + m.value().cast>(); + } catch (const py::cast_error& e) { + throw std::invalid_argument( + "[save_safetensors] Metadata must be a dictionary with string keys and values"); + } + } else { + metadata_map = std::unordered_map(); + } auto arrays_map = d.cast>(); if (py::isinstance(file)) { { py::gil_scoped_release nogil; - save_safetensors(py::cast(file), arrays_map); + save_safetensors(py::cast(file), arrays_map, metadata_map); } } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { py::gil_scoped_release nogil; - save_safetensors(writer, arrays_map); + save_safetensors(writer, arrays_map, metadata_map); } } else { throw std::invalid_argument( @@ -471,7 +488,7 @@ void mlx_save_gguf_helper( if (py::isinstance(file)) { if (m) { auto metadata_map = - m.value().cast>(); + m.value().cast>(); { py::gil_scoped_release nogil; save_gguf(py::cast(file), arrays_map, metadata_map); diff --git a/python/src/load.h b/python/src/load.h index dbe0f9cd6..21f0cff32 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -15,19 +15,17 @@ using namespace mlx::core; using LoadOutputTypes = std::variant< array, std::unordered_map, - std::pair< - std::unordered_map, - std::unordered_map>>; + SafetensorsLoad, + GGUFLoad>; -std::unordered_map mlx_load_safetensor_helper( +SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s); +void mlx_save_safetensor_helper( py::object file, - StreamOrDevice s); -void mlx_save_safetensor_helper(py::object file, py::dict d); + py::dict d, + std::optional m); + +GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s); -std::pair< - std::unordered_map, - std::unordered_map> -mlx_load_gguf_helper(py::object file, StreamOrDevice s); void mlx_save_gguf_helper( py::object file, py::dict d, diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index 81626e565..5fb9e74e2 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -17,6 +17,8 @@ void init_random(py::module_&); void init_fft(py::module_&); void init_linalg(py::module_&); void init_constants(py::module_&); +void init_extensions(py::module_&); +void init_utils(py::module_&); PYBIND11_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -33,5 +35,8 @@ PYBIND11_MODULE(core, m) { init_fft(m); init_linalg(m); init_constants(m); + init_extensions(m); + init_utils(m); + m.attr("__version__") = TOSTRING(_VERSION_); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 02a401543..2c2dcecfd 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) { &mlx_save_safetensor_helper, "file"_a, "arrays"_a, + "metadata"_a = none, R"pbdoc( - save_safetensors(file: str, arrays: Dict[str, array]) + save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None) Save array(s) to a binary file in ``.safetensors`` format. @@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) { Args: file (file, str): File in which the array is saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved. + metadata (dict(str, str), optional): The dictionary of metadata to be saved. )pbdoc"); m.def( "save_gguf", @@ -3634,4 +3636,64 @@ void init_ops(py::module_& m) { Returns: array: The extracted diagonal or the constructed diagonal matrix. )pbdoc"); + m.def( + "atleast_1d", + &atleast_1d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least one dimension. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least one dimension. + + )pbdoc"); + m.def( + "atleast_2d", + &atleast_2d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least two dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least two dimensions. + + )pbdoc"); + m.def( + "atleast_3d", + &atleast_3d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least three dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least three dimensions. + + )pbdoc"); } diff --git a/python/src/random.cpp b/python/src/random.cpp index 6e9f38d97..442d81fee 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "python/src/utils.h" @@ -13,13 +14,55 @@ using namespace py::literals; using namespace mlx::core; using namespace mlx::core::random; +class PyKeySequence { + public: + explicit PyKeySequence(uint64_t seed) { + state_.append(key(seed)); + } + + void seed(uint64_t seed) { + state_[0] = key(seed); + } + + array next() { + auto out = split(py::cast(state_[0])); + state_[0] = out.first; + return out.second; + } + + py::list state() { + return state_; + } + + void release() { + py::gil_scoped_acquire gil; + state_.release().dec_ref(); + } + + private: + py::list state_; +}; + +PyKeySequence& default_key() { + auto get_current_time_seed = []() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + }; + static PyKeySequence ks(get_current_time_seed()); + return ks; +} + void init_random(py::module_& parent_module) { auto m = parent_module.def_submodule( "random", "mlx.core.random: functionality related to random number generation"); + + m.attr("state") = default_key().state(); m.def( "seed", - &seed, + [](uint64_t seed) { default_key().seed(seed); }, "seed"_a, R"pbdoc( Seed the global PRNG. @@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& high, const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return uniform( to_array(low), to_array(high), @@ -89,7 +133,7 @@ void init_random(py::module_& parent_module) { low (scalar or array, optional): Lower bound of the distribution. Default is ``0``. high (scalar or array, optional): Upper bound of the distribution. Default is ``1``. shape (list(int), optional): Shape of the output. Default is ``()``. - key (array, optional): A PRNG key. Default: None. + key (array, optional): A PRNG key. Default: ``None``. dtype (Dtype, optional): Type of the output. Default is ``float32``. Returns: @@ -99,13 +143,17 @@ void init_random(py::module_& parent_module) { "normal", [](const std::vector& shape, std::optional type, - const std::optional& key, + float loc, + float scale, + const std::optional& key_, StreamOrDevice s) { - return normal(shape, type.value_or(float32), key, s); + auto key = key_ ? key_.value() : default_key().next(); + return normal(shape, type.value_or(float32), loc, scale, key, s); }, - "shape"_a = std::vector{}, "dtype"_a = std::optional{float32}, + "loc"_a = 0.0, + "scale"_a = 1.0, "key"_a = none, "stream"_a = none, R"pbdoc( @@ -114,6 +162,8 @@ void init_random(py::module_& parent_module) { Args: shape (list(int), optional): Shape of the output. Default is ``()``. dtype (Dtype, optional): Type of the output. Default is ``float32``. + loc (float, optional): Mean of the distribution. Default is ``0.0``. + scale (float, optional): Standard deviation of the distribution. Default is ``1.0``. key (array, optional): A PRNG key. Default: None. Returns: @@ -125,8 +175,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& high, const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return randint( to_array(low), to_array(high), shape, type.value_or(int32), key, s); }, @@ -157,8 +208,9 @@ void init_random(py::module_& parent_module) { "bernoulli", [](const ScalarOrArray& p_, const std::optional> shape, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); auto p = to_array(p_); if (shape.has_value()) { return bernoulli(p, shape.value(), key, s); @@ -193,8 +245,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& upper_, const std::optional> shape_, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); auto lower = to_array(lower_); auto upper = to_array(upper_); auto t = type.value_or(float32); @@ -233,8 +286,9 @@ void init_random(py::module_& parent_module) { "gumbel", [](const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return gumbel(shape, type.value_or(float32), key, s); }, "shape"_a = std::vector{}, @@ -261,8 +315,9 @@ void init_random(py::module_& parent_module) { int axis, const std::optional> shape, const std::optional num_samples, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); if (shape.has_value() && num_samples.has_value()) { throw std::invalid_argument( "[categorical] At most one of shape or num_samples can be specified."); @@ -303,4 +358,7 @@ void init_random(py::module_& parent_module) { Returns: array: The ``shape``-sized output array with type ``uint32``. )pbdoc"); + // Register static Python object cleanup before the interpreter exits + auto atexit = py::module_::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { default_key().release(); })); } diff --git a/python/src/stream.cpp b/python/src/stream.cpp index 7b1b2f55d..768795fc1 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -12,7 +12,12 @@ using namespace py::literals; using namespace mlx::core; void init_stream(py::module_& m) { - py::class_(m, "Stream") + py::class_( + m, + "Stream", + R"pbdoc( + A stream for running operations on a given device. + )pbdoc") .def(py::init(), "index"_a, "device"_a) .def_readonly("device", &Stream::device) .def( @@ -28,7 +33,27 @@ void init_stream(py::module_& m) { py::implicitly_convertible(); - m.def("default_stream", &default_stream, "device"_a); - m.def("set_default_stream", &set_default_stream, "stream"_a); - m.def("new_stream", &new_stream, "device"_a); + m.def( + "default_stream", + &default_stream, + "device"_a, + R"pbdoc(Get the device's default stream.)pbdoc"); + m.def( + "set_default_stream", + &set_default_stream, + "stream"_a, + R"pbdoc( + Set the default stream. + + This will make the given stream the default for the + streams device. It will not change the default device. + + Args: + stream (stream): Stream to make the default. + )pbdoc"); + m.def( + "new_stream", + &new_stream, + "device"_a, + R"pbdoc(Make a new stream on the given device.)pbdoc"); } diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 78f867876..cda1d6316 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -135,6 +135,64 @@ py::object tree_map( }); } +void tree_visit_update( + py::object tree, + std::function visitor) { + std::function recurse; + recurse = [&](py::handle subtree) { + if (py::isinstance(subtree)) { + auto l = py::cast(subtree); + for (int i = 0; i < l.size(); ++i) { + l[i] = recurse(l[i]); + } + return py::cast(l); + } else if (py::isinstance(subtree)) { + for (auto item : subtree) { + recurse(item); + } + return py::cast(subtree); + } else if (py::isinstance(subtree)) { + auto d = py::cast(subtree); + for (auto item : d) { + d[item.first] = recurse(item.second); + } + return py::cast(d); + } else if (py::isinstance(subtree)) { + return visitor(subtree); + } else { + return py::cast(subtree); + } + }; + recurse(tree); +} + +// Fill a pytree (recursive dict or list of dict or list) +// in place with the given arrays +// Non dict or list nodes are ignored +void tree_fill(py::object& tree, const std::vector& values) { + size_t index = 0; + tree_visit_update( + tree, [&](py::handle node) { return py::cast(values[index++]); }); +} + +// Replace all the arrays from the src values with the dst values in the tree +void tree_replace( + py::object& tree, + const std::vector& src, + const std::vector& dst) { + std::unordered_map src_to_dst; + for (int i = 0; i < src.size(); ++i) { + src_to_dst.insert({src[i].id(), dst[i]}); + } + tree_visit_update(tree, [&](py::handle node) { + auto arr = py::cast(node); + if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { + return py::cast(it->second); + } + return py::cast(arr); + }); +} + std::vector tree_flatten(py::object tree, bool strict = true) { std::vector flat_tree; @@ -495,9 +553,21 @@ std::unordered_map& tree_cache() { struct PyCompiledFun { py::function fun; size_t fun_id; + py::object captured_inputs; + py::object captured_outputs; + bool shapeless; + size_t num_outputs{0}; - PyCompiledFun(const py::function& fun) - : fun(fun), fun_id(reinterpret_cast(fun.ptr())) {} + PyCompiledFun( + const py::function& fun, + py::object inputs, + py::object outputs, + bool shapeless) + : fun(fun), + fun_id(reinterpret_cast(fun.ptr())), + captured_inputs(inputs), + captured_outputs(outputs), + shapeless(shapeless) {} PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete; @@ -505,23 +575,105 @@ struct PyCompiledFun { PyCompiledFun(PyCompiledFun&& other) : fun(std::move(other.fun)), fun_id(reinterpret_cast(fun.ptr())) { other.fun_id = 0; + captured_inputs = std::move(other.captured_inputs); + captured_outputs = std::move(other.captured_outputs); + shapeless = other.shapeless; + num_outputs = other.num_outputs; }; - py::object operator()(const py::args& args) { - auto compile_fun = [this, &args](const std::vector& a) { - // Call the python function and flatten the outputs - auto [outputs, py_outputs] = tree_flatten_with_structure( - std::move(this->fun(*tree_unflatten(args, a))), true); + py::object operator()(const py::args& args, const py::kwargs& kwargs) { + auto inputs = tree_flatten(args, false); - tree_cache().insert({this->fun_id, py_outputs}); + auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()]( + const std::vector& a) { + // Put tracers into captured inputs + std::vector flat_in_captures; + std::vector trace_captures; + if (!py::isinstance(captured_inputs)) { + flat_in_captures = tree_flatten(captured_inputs, false); + trace_captures.insert( + trace_captures.end(), a.end() - flat_in_captures.size(), a.end()); + tree_fill(captured_inputs, trace_captures); + } + + auto tree_outputs = + fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args)); + auto [outputs, py_outputs] = + tree_flatten_with_structure(std::move(tree_outputs), false); + + tree_cache().insert({fun_id, py_outputs}); + + num_outputs = outputs.size(); + if (!py::isinstance(captured_outputs)) { + auto flat_out_captures = tree_flatten(captured_outputs, false); + outputs.insert( + outputs.end(), + std::make_move_iterator(flat_out_captures.begin()), + std::make_move_iterator(flat_out_captures.end())); + } + + // Replace tracers with originals in captured inputs + if (!py::isinstance(captured_inputs)) { + tree_replace(captured_inputs, trace_captures, flat_in_captures); + } return outputs; }; - // Inputs must be array or tree of arrays - auto inputs = tree_flatten(args, true); + { + auto flat_kwargs = tree_flatten(kwargs, false); + inputs.insert( + inputs.end(), + std::make_move_iterator(flat_kwargs.begin()), + std::make_move_iterator(flat_kwargs.end())); + } + + if (!py::isinstance(captured_inputs)) { + auto flat_in_captures = tree_flatten(captured_inputs, false); + inputs.insert( + inputs.end(), + std::make_move_iterator(flat_in_captures.begin()), + std::make_move_iterator(flat_in_captures.end())); + } + + // Collect the compilation constants + std::vector constants; + auto value_hash = [](py::handle o) -> std::optional { + // Consider expanding tuples to their contents including start and end + // ids + if (py::isinstance(o) || py::isinstance(o)) { + auto r = py::hash(o); + return *reinterpret_cast(&r); + } else if (py::isinstance(o)) { + auto r = o.cast(); + return *reinterpret_cast(&r); + } else if (py::isinstance(o)) { + auto r = o.cast(); + return *reinterpret_cast(&r); + } else { + return std::nullopt; + } + }; + for (int i = 0; i < args.size(); i++) { + if (auto h = value_hash(args[i]); h.has_value()) { + constants.push_back(*h); + } + } + for (auto& pair : kwargs) { + if (auto h = value_hash(pair.second); h.has_value()) { + constants.push_back(*value_hash(pair.first)); + constants.push_back(*h); + } + } // Compile and call - auto outputs = detail::compile(compile_fun, fun_id)(inputs); + auto outputs = + detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); + if (!py::isinstance(captured_outputs)) { + std::vector captures( + std::make_move_iterator(outputs.begin() + num_outputs), + std::make_move_iterator(outputs.end())); + tree_fill(captured_outputs, captures); + } // Put the outputs back in the container py::object py_outputs = tree_cache().at(fun_id); @@ -534,6 +686,8 @@ struct PyCompiledFun { tree_cache().erase(fun_id); detail::compile_erase(fun_id); fun.release().dec_ref(); + captured_inputs.release().dec_ref(); + captured_outputs.release().dec_ref(); } }; @@ -601,7 +755,7 @@ void init_transforms(py::module_& m) { m.def( "eval", [](const py::args& args) { - std::vector arrays = tree_flatten(args); + std::vector arrays = tree_flatten(args, false); { py::gil_scoped_release nogil; eval(arrays); @@ -615,8 +769,8 @@ void init_transforms(py::module_& m) { Args: *args (arrays or trees of arrays): Each argument can be a single array or a tree of arrays. If a tree is given the nodes can be a Python - :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be - an :class:`array`. + :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not + arrays are ignored. )pbdoc"); m.def( "jvp", @@ -859,10 +1013,16 @@ void init_transforms(py::module_& m) { "file"_a); m.def( "compile", - [](const py::function& fun) { - return py::cpp_function(PyCompiledFun{fun}); + [](const py::function& fun, + const py::object& inputs, + const py::object& outputs, + bool shapeless) { + return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}); }, "fun"_a, + "inputs"_a = std::nullopt, + "outputs"_a = std::nullopt, + "shapeless"_a = false, R"pbdoc( compile(fun: function) -> function @@ -872,6 +1032,22 @@ void init_transforms(py::module_& m) { fun (function): A function which takes a variable number of :class:`array` or trees of :class:`array` and returns a variable number of :class:`array` or trees of :class:`array`. + inputs (list or dict, optional): These inputs will be captured during + the function compilation along with the inputs to ``fun``. The ``inputs`` + can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested + lists, dictionaries, or arrays. Leaf nodes that are not + :obj:`array` are ignored. Default: ``None`` + outputs (list or dict, optional): These outputs will be captured and + updated in a compiled function. The ``outputs`` can be a + :obj:`list` or a :obj:`dict` containing arbitrarily nested lists, + dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. + Default: ``None`` + shapeless (bool, optional): A function compiled with the ``shapeless`` + option enabled will not be recompiled when the input shape changes. Not all + functions can be compiled with ``shapeless`` enabled. Attempting to compile + such functions with shapeless enabled will throw. Note, changing the number + of dimensions or type of any input will result in a recompilation even with + ``shapeless`` set to ``True``. Default: ``False`` Returns: function: A compiled function which has the same input arguments @@ -890,7 +1066,7 @@ void init_transforms(py::module_& m) { "enable_compile", &enable_compile, R"pbdoc( - enable_compiler() -> None + enable_compile() -> None Globally enable compilation. This will override the environment variable ``MLX_DISABLE_COMPILE`` if set. diff --git a/python/src/utils.cpp b/python/src/utils.cpp new file mode 100644 index 000000000..c07016709 --- /dev/null +++ b/python/src/utils.cpp @@ -0,0 +1,81 @@ + +#include "mlx/utils.h" +#include +#include +#include + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +// Slightly different from the original, with python context on init we are not +// in the context yet. Only create the inner context on enter then delete on +// exit. +class PyStreamContext { + public: + PyStreamContext(StreamOrDevice s) : _inner(nullptr) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + _s = s; + } + + void enter() { + _inner = new StreamContext(_s); + } + + void exit() { + if (_inner != nullptr) { + delete _inner; + _inner = nullptr; + } + } + + private: + StreamOrDevice _s; + StreamContext* _inner; +}; + +void init_utils(py::module_& m) { + py::class_(m, "StreamContext", R"pbdoc( + A context manager for setting the current device and stream. + + See :func:`stream` for usage. + + Args: + s: The stream or device to set as the default. + )pbdoc") + .def(py::init(), "s"_a) + .def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) + .def( + "__exit__", + [](PyStreamContext& scm, + const std::optional& exc_type, + const std::optional& exc_value, + const std::optional& traceback) { scm.exit(); }); + m.def( + "stream", + [](StreamOrDevice s) { return PyStreamContext(s); }, + "s"_a, + R"pbdoc( + Create a context manager to set the default device and stream. + + Args: + s: The :obj:`Stream` or :obj:`Device` to set as the default. + + Returns: + A context manager that sets the default device and stream. + + Example: + + .. code-block::python + + import mlx.core as mx + + # Create a context manager for the default device and stream. + with mx.stream(mx.cpu): + # Operations here will use mx.cpu by default. + pass + )pbdoc"); +} diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 40d0923f4..d30fd0de0 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -431,6 +431,14 @@ class TestArray(mlx_tests.MLXTestCase): x = mx.array(vals) self.assertEqual(x.tolist(), vals) + # Half types + vals = [1.0, 2.0, 3.0, 4.0, 5.0] + x = mx.array(vals, dtype=mx.float16) + self.assertEqual(x.tolist(), vals) + + x = mx.array(vals, dtype=mx.bfloat16) + self.assertEqual(x.tolist(), vals) + def test_array_np_conversion(self): # Shape test a = np.array([]) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 56dff8b3d..e53134482 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -2,6 +2,7 @@ import io import unittest +from functools import partial import mlx.core as mx import mlx_tests @@ -301,6 +302,243 @@ class TestCompile(mlx_tests.MLXTestCase): cdfdx = mx.grad(outer)(x) self.assertTrue(mx.allclose(dfdx, cdfdx)) + def test_compile_capture(self): + # Test update captured state outside compiled function + state = {"y": mx.array(2)} + + @partial(mx.compile, inputs=state) + def test_state(x): + x = x + state["y"] + return x + + test_state(mx.array(1)) + # Check the state is unchanged + self.assertEqual(state["y"], 2) + + # Check the udpated state is used + state["y"] = mx.array(3) + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Capture list + state = [mx.array(2)] + + @partial(mx.compile, inputs=state) + def test_state(x): + x = x + state[0] + return x + + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 3) + state[0] = mx.array(3) + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Capture tuple of list + state = ([mx.array(2)],) + + @partial(mx.compile, inputs=state) + def test_state(x): + x = x + state[0][0] + return x + + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 3) + state[0][0] = mx.array(3) + out = test_state(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Test state updated inside compiled function + state = {} + + @partial(mx.compile, outputs=state) + def test_state(x): + state["y"] = x + 3 + return mx.abs(x) + + test_state(mx.array(-1)) + self.assertEqual(state["y"].item(), 2) + + # Test state changed inside compiled function + # triggers recompile + state = {} + + @partial(mx.compile, inputs=state, outputs=state) + def test_state(x): + y = state.get("y", mx.array(0)) + state["y"] = x + y + return x + 2 * y + + test_state(mx.array(1)) + self.assertEqual(state["y"].item(), 1) + test_state(mx.array(1)) + self.assertEqual(state["y"].item(), 2) + + def test_compile_rng(self): + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) + def fun(): + return mx.random.uniform(shape=(10, 10)) + + self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) + + def test_compile_kwargs(self): + + @mx.compile + def fun(x, y, z): + return x + y + z + + x = mx.array(1) + y = mx.array(2) + z = mx.array(3) + out = fun(x, y=y, z=z) + self.assertEqual(out.item(), 6) + + def test_shapeless_compile(self): + y = 1 + + @partial(mx.compile, shapeless=True) + def fun(x): + return x + y + + x = mx.array([1, 2]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) + + # The function is not recompiled, so the change + # to y should not be reflected in the output + y = 2 + x = mx.array([1, 2, 3]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) + + # Type change recompiles + x = mx.array([1.0, 2.0, 3.0]) + self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) + fun(x, y=y, z=z) + + def test_shapeless_compile(self): + y = 1 + + @partial(mx.compile, shapeless=True) + def fun(x): + return x + y + + x = mx.array([1, 2]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) + + # The function is not recompiled, so the change + # to y should not be reflected in the output + y = 2 + x = mx.array([1, 2, 3]) + self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) + + # Type change recompiles + x = mx.array([1.0, 2.0, 3.0]) + self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) + + # Dim change recompiles + x = mx.array([[1, 2, 3]]) + self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]]))) + + def test_shapeless_compile_with_broadcasts(self): + x = mx.ones((2, 2)) + y = mx.array([2, 2]) + + def fun(x, y): + return x * y + + cfun = mx.compile(fun, shapeless=True) + self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y))) + self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x))) + y = mx.array([[3]]) + self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y))) + self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x))) + + def test_shapeless_compile_with_reduction(self): + # Test shapeless compile with a reduction + z = 1 + + @partial(mx.compile, shapeless=True) + def fun(x, y): + return x + y.sum(0, keepdims=True) + z + + x = mx.ones((2, 2), mx.int32) + y = mx.ones((2, 2), mx.int32) + self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4))) + x = mx.ones((3, 3), mx.int32) + y = mx.ones((3, 3), mx.int32) + z = 2 + self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5))) + + x1 = mx.array([[1, 2], [3, 4], [5, 6]]) + x2 = mx.array([[1, 2]]) + + def fun(x): + return x * x.sum(-1, keepdims=True) + + cfun = mx.compile(fun, shapeless=True) + mx.eval(cfun(x1)) + self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) + + def test_compile_with_constant(self): + + # Test float + @partial(mx.compile) + def fun(x, y): + return x + y + + z = fun(mx.array(1.0), 1.0) + self.assertEqual(z.item(), 2.0) + + z = fun(mx.array(1.0), 2.0) + self.assertEqual(z.item(), 3.0) + + z = fun(mx.array(1.0), y=1.0) + self.assertEqual(z.item(), 2.0) + + z = fun(mx.array(1.0), y=3.0) + self.assertEqual(z.item(), 4.0) + + # Test tuple + @partial(mx.compile) + def fun(x, y=(1, 2)): + return x + y[0] + y[1] + + z = fun(mx.array(1)) + self.assertEqual(z.item(), 4) + + z = fun(mx.array(1), (2, 2)) + self.assertEqual(z.item(), 5) + + z = fun(mx.array(1), (2, 1)) + self.assertEqual(z.item(), 4) + + # Test bool + @partial(mx.compile) + def fun(x, y): + if y: + return x + 1 + else: + return x + 2 + + z = fun(mx.array(1), True) + self.assertEqual(z.item(), 2) + + z = fun(mx.array(1), False) + self.assertEqual(z.item(), 3) + + # Test string + @partial(mx.compile) + def fun(x, y): + if y == "one": + return x + 1 + else: + return x + 2 + + z = fun(mx.array(1), "one") + self.assertEqual(z.item(), 2) + + z = fun(mx.array(1), "two") + self.assertEqual(z.item(), 3) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 8aac105bc..53826cad7 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase): # Restore device mx.set_default_device(device) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_device_context(self): + default = mx.default_device() + diff = mx.cpu if default == mx.gpu else mx.gpu + self.assertNotEqual(default, diff) + with mx.stream(diff): + a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2))) + mx.eval(a) + self.assertEqual(mx.default_device(), diff) + self.assertEqual(mx.default_device(), default) + def test_op_on_device(self): x = mx.array(1.0) y = mx.array(1.0) diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 6619afa67..dc986a19a 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase): y = dfun_dx(mx.array(1.0)) self.assertEqual(y.item(), 6.0) + def test_eval_mixed(self): + x = mx.array(1) + 1 + 1 + y = 0 + z = "hello" + state = [x, y, z] + mx.eval(state) + self.assertEqual(x.item(), 3) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py new file mode 100644 index 000000000..1cb4ddcca --- /dev/null +++ b/python/tests/test_fast.py @@ -0,0 +1,158 @@ +# Copyright © 2023-2024 Apple Inc. + +import math +import unittest + +import mlx.core as mx +import mlx_tests + + +def rope_orig(x, dims, traditional, base, scale, offset): + N = x.shape[1] + offset + dtype = x.dtype + half_D = dims // 2 + positions = mx.arange(offset, N, dtype=dtype) * scale + freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)) + theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) + costheta, sintheta = mx.cos(theta), mx.sin(theta) + if traditional: + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) + return mx.reshape(rx, x.shape) + else: + x1 = x[..., : dims // 2] + x2 = x[..., dims // 2 : dims] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + if dims < x.shape[-1]: + rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1) + else: + rx = mx.concatenate([rx1, rx2], axis=-1) + return rx + + +class TestFast(mlx_tests.MLXTestCase): + def test_rope(self): + T = 4 + + # Defaults: dims, dtype, base, scale, offset, traditional + defaults = (8, mx.float32, 10000.0, 1.0, 0, False) + + # Per dtype absolute tolerance + tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2} + + # Test cases: + dtypes = [mx.float32, mx.float16, mx.bfloat16] + bases = [10000.0, 1000000.0] + scales = [1.0, 2.0] + offsets = [0, 3] + traditional = [True, False] + + for traditional in [True, False]: + dims, dtype, _, scale, offset, _ = defaults + for base in bases: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + dims, _, base, scale, offset, _ = defaults + for dtype in dtypes: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + ry = rope_orig( + x.astype(mx.float32), dims, traditional, base, scale, offset + ) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + if dtype != mx.float32: + self.assertLessEqual( + mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max() + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + dims, dtype, base, scale, _, _ = defaults + for offset in offsets: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + dims, dtype, base, _, offset, _ = defaults + for scale in scales: + x = mx.random.uniform(shape=(2, T, dims)).astype(dtype) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + x, + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + def test_fast_transforms(self): + x = mx.random.uniform(shape=(2, 2, 8)) + + defaults = (8, False, 10000.0, 1.0, 0) + dims, traditional, base, scale, offset = defaults + + # VJP + _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),)) + _, vjp_fast_out = mx.vjp( + lambda x: mx.fast.rope( + x, dims, traditional=traditional, base=base, scale=scale, offset=offset + ), + (x,), + (mx.ones_like(x),), + ) + self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0])) + + # JVP + _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),)) + _, jvp_fast_out = mx.jvp( + lambda x: mx.fast.rope( + x, dims, traditional=traditional, base=base, scale=scale, offset=offset + ), + (x,), + (mx.ones_like(x),), + ) + self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0])) + + # VMAP + x = mx.random.uniform(shape=(2, 2, 2, 8)) + vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x) + vmap_fast_out = mx.vmap( + lambda x: mx.fast.rope( + x, dims, traditional=traditional, base=base, scale=scale, offset=offset + ) + )(x) + self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 4be12e21f..14473afa1 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -19,72 +19,73 @@ class TestFFT(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) def test_fft(self): - default = mx.default_device() - mx.set_default_device(mx.cpu) - def check_mx_np(op_mx, op_np, a_np, **kwargs): out_np = op_np(a_np, **kwargs) a_mx = mx.array(a_np) out_mx = op_mx(a_mx, **kwargs) self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np) + with mx.stream(mx.cpu): + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np) - # Check with slicing and padding - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) - check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) + # Check with slicing and padding + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) - # Check different axes - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) - check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) + # Check different axes + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) - # Check real fft - a_np = np.random.rand(100).astype(np.float32) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) - check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) + # Check real fft + a_np = np.random.rand(100).astype(np.float32) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) - # Check real inverse - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) - check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) - check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) - - mx.set_default_device(default) + # Check real inverse + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) def test_fftn(self): - default = mx.default_device() - mx.set_default_device(mx.cpu) + with mx.stream(mx.cpu): + r = np.random.randn(8, 8, 8).astype(np.float32) + i = np.random.randn(8, 8, 8).astype(np.float32) + a = r + 1j * i - r = np.random.randn(8, 8, 8).astype(np.float32) - i = np.random.randn(8, 8, 8).astype(np.float32) - a = r + 1j * i + axes = [None, (1, 2), (2, 1), (0, 2)] + shapes = [None, (10, 5), (5, 10)] + ops = [ + "fft2", + "ifft2", + "rfft2", + "irfft2", + "fftn", + "ifftn", + "rfftn", + "irfftn", + ] - axes = [None, (1, 2), (2, 1), (0, 2)] - shapes = [None, (10, 5), (5, 10)] - ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] - - for op, ax, s in itertools.product(ops, axes, shapes): - x = a - if op in ["rfft2", "rfftn"]: - x = r - self.check_mx_np(op, x, axes=ax, s=s) - - mx.set_default_device(default) + for op, ax, s in itertools.product(ops, axes, shapes): + x = a + if op in ["rfft2", "rfftn"]: + x = r + self.check_mx_np(op, x, axes=ax, s=s) if __name__ == "__main__": diff --git a/python/tests/test_load.py b/python/tests/test_load.py index a37ba83a9..fdf06041a 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase): def test_save_and_load_safetensors(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir) + with self.assertRaises(Exception): + mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0}) + + mx.save_safetensors( + "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"} + ) + res = mx.load("test.safetensors", return_metadata=True) + self.assertEqual(len(res), 2) + self.assertEqual(res[1], {"testing": "test", "format": "mlx"}) for dt in self.dtypes + ["bfloat16"]: with self.subTest(dtype=dt): @@ -75,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase): self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } with open(save_file_mlx, "wb") as f: @@ -104,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase): self.test_dir, f"mlx_{dt}_{i}_fs.gguf" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } mx.save_gguf(save_file_mlx, save_dict) diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index 2160b0a6e..3a430be21 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase): expected_sum = mx.sum(expected_none) self.assertEqual(losses_sum, expected_sum) + # With weights, no label smoothing + weights = mx.array([1.0, 2.0, 1.0, 2.0]) + expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944]) + loss = nn.losses.binary_cross_entropy( + logits, targets, weights=weights, reduction="none" + ) + self.assertTrue(mx.allclose(loss, expected)) + def _test_probs_as_inputs(): probs = mx.array([0.5, 0.6, 0.7, 0.8]) targets = mx.array([0, 0, 1, 1]) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 7749e159a..eaaf3bb9c 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -71,7 +71,7 @@ class TestBase(mlx_tests.MLXTestCase): def test_save_safetensors_weights(self): def make_model(): - return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) + return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU()) m = make_model() tdir = tempfile.TemporaryDirectory() @@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase): ] ) + def test_module_state(self): + m = nn.Linear(10, 1) + m.state["hello"] = "world" + self.assertEqual(m.state["hello"], "world") + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): @@ -900,6 +905,347 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertTrue(y.shape, x.shape) self.assertTrue(y.dtype, mx.float16) + def test_pooling(self): + # Test 1d pooling + x = mx.array( + [ + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], + [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], + ] + ) + expected_max_pool_output_no_padding_stride_1 = [ + [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_no_padding_stride_2 = [ + [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_padding_1_stride_2 = [ + [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + [[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ] + expected_max_pool_output_padding_1_stride_2_kernel_3 = [ + [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]], + [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]], + ] + expected_avg_pool_output_no_padding_stride_1 = [ + [ + [1.5000, 2.5000, 3.5000], + [4.5000, 5.5000, 6.5000], + [7.5000, 8.5000, 9.5000], + ], + [ + [13.5000, 14.5000, 15.5000], + [16.5000, 17.5000, 18.5000], + [19.5000, 20.5000, 21.5000], + ], + ] + expected_avg_pool_output_no_padding_stride_2 = [ + [[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]], + [[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]], + ] + expected_avg_pool_output_padding_1_stride_2 = [ + [ + [0.0000, 0.5000, 1.0000], + [4.5000, 5.5000, 6.5000], + [4.5000, 5.0000, 5.5000], + ], + [ + [6.0000, 6.5000, 7.0000], + [16.5000, 17.5000, 18.5000], + [10.5000, 11.0000, 11.5000], + ], + ] + expected_avg_pool_output_padding_1_kernel_3 = [ + [[1, 1.66667, 2.33333], [6, 7, 8]], + [[9, 9.66667, 10.3333], [18, 19, 20]], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=1, padding=0)(x), + expected_max_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=2, padding=0)(x), + expected_max_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=2, stride=2, padding=1)(x), + expected_max_pool_output_padding_1_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool1d(kernel_size=3, stride=2, padding=1)(x), + expected_max_pool_output_padding_1_stride_2_kernel_3, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=1, padding=0)(x), + expected_avg_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=2, padding=0)(x), + expected_avg_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=2, stride=2, padding=1)(x), + expected_avg_pool_output_padding_1_stride_2, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool1d(kernel_size=3, stride=2, padding=1)(x), + expected_avg_pool_output_padding_1_kernel_3, + ) + ) + # Test 2d pooling + x = mx.array( + [ + [ + [[0, 16], [1, 17], [2, 18], [3, 19]], + [[4, 20], [5, 21], [6, 22], [7, 23]], + [[8, 24], [9, 25], [10, 26], [11, 27]], + [[12, 28], [13, 29], [14, 30], [15, 31]], + ] + ] + ) + expected_max_pool_output_no_padding_stride_1 = [ + [ + [[5, 21], [6, 22], [7, 23]], + [[9, 25], [10, 26], [11, 27]], + [[13, 29], [14, 30], [15, 31]], + ] + ] + expected_max_pool_output_no_padding_stride_2 = [ + [[[5, 21], [7, 23]], [[13, 29], [15, 31]]] + ] + expected_max_pool_output_padding_1 = [ + [ + [[0, 16], [2, 18], [3, 19]], + [[8, 24], [10, 26], [11, 27]], + [[12, 28], [14, 30], [15, 31]], + ] + ] + expected_mean_pool_output_no_padding_stride_1 = [ + [ + [[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]], + [[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]], + [[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]], + ] + ] + expected_mean_pool_output_no_padding_stride_2 = [ + [ + [[2.5000, 18.5000], [4.5000, 20.5000]], + [[10.5000, 26.5000], [12.5000, 28.5000]], + ] + ] + expected_mean_pool_output_padding_1 = [ + [ + [[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]], + [[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]], + [[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]], + ] + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=1, padding=0)(x), + expected_max_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=2, padding=0)(x), + expected_max_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(x), + expected_max_pool_output_padding_1, + ) + ) + # Average pooling + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=2, stride=1, padding=0)(x), + expected_mean_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(x), + expected_mean_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool2d(kernel_size=2, stride=2, padding=1)(x), + expected_mean_pool_output_padding_1, + ) + ) + # Test multiple batches + x = mx.array( + [ + [ + [[0, 1], [2, 3], [4, 5], [6, 7]], + [[8, 9], [10, 11], [12, 13], [14, 15]], + [[16, 17], [18, 19], [20, 21], [22, 23]], + [[24, 25], [26, 27], [28, 29], [30, 31]], + ], + [ + [[32, 33], [34, 35], [36, 37], [38, 39]], + [[40, 41], [42, 43], [44, 45], [46, 47]], + [[48, 49], [50, 51], [52, 53], [54, 55]], + [[56, 57], [58, 59], [60, 61], [62, 63]], + ], + ] + ) + expected_max_pool_output = [ + [[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]], + [[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]], + ] + expected_avg_pool_output = [ + [[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]], + [[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x), + expected_max_pool_output, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1)(x), + expected_avg_pool_output, + ) + ) + # Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2) + x = mx.array( + [ + [ + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], + [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]], + [[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]], + [[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]], + ], + [ + [[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]], + [[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]], + [[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]], + [[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]], + ], + ] + ) + expected_irregular_max_pool_output = [ + [ + [ + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [9.0, 10.0, 11.0], + [9.0, 10.0, 11.0], + ], + [ + [39.0, 40.0, 41.0], + [42.0, 43.0, 44.0], + [45.0, 46.0, 47.0], + [45.0, 46.0, 47.0], + [45.0, 46.0, 47.0], + ], + ], + [ + [ + [51.0, 52.0, 53.0], + [54.0, 55.0, 56.0], + [57.0, 58.0, 59.0], + [57.0, 58.0, 59.0], + [57.0, 58.0, 59.0], + ], + [ + [87.0, 88.0, 89.0], + [90.0, 91.0, 92.0], + [93.0, 94.0, 95.0], + [93.0, 94.0, 95.0], + [93.0, 94.0, 95.0], + ], + ], + ] + expected_irregular_average_pool_output = [ + [ + [ + [0.3750, 0.6250, 0.8750], + [1.1250, 1.5000, 1.8750], + [2.2500, 2.7500, 3.2500], + [2.2500, 2.6250, 3.0000], + [1.8750, 2.1250, 2.3750], + ], + [ + [15.7500, 16.2500, 16.7500], + [24.7500, 25.5000, 26.2500], + [34.5000, 35.5000, 36.5000], + [27.0000, 27.7500, 28.5000], + [18.7500, 19.2500, 19.7500], + ], + ], + [ + [ + [12.3750, 12.6250, 12.8750], + [19.1250, 19.5000, 19.8750], + [26.2500, 26.7500, 27.2500], + [20.2500, 20.6250, 21.0000], + [13.8750, 14.1250, 14.3750], + ], + [ + [39.7500, 40.2500, 40.7500], + [60.7500, 61.5000, 62.2500], + [82.5000, 83.5000, 84.5000], + [63.0000, 63.7500, 64.5000], + [42.7500, 43.2500, 43.7500], + ], + ], + ] + self.assertTrue( + np.array_equal( + nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), + expected_irregular_max_pool_output, + ) + ) + self.assertTrue( + np.allclose( + nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x), + expected_irregular_average_pool_output, + ) + ) + # Test repr + self.assertEqual( + str(nn.MaxPool1d(kernel_size=3, padding=2)), + "MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))", + ) + self.assertEqual( + str(nn.AvgPool1d(kernel_size=2, stride=3)), + "AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))", + ) + self.assertEqual( + str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + "MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))", + ) + self.assertEqual( + str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))), + "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", + ) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 6ac46779d..3401338f8 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import math +import os import unittest from itertools import permutations @@ -274,6 +275,20 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(z.dtype, dt) self.assertEqual(z.item(), 1) + z = -1 % x + self.assertEqual(z.dtype, dt) + self.assertEqual(z.item(), 1) + + z = -1 % -x + self.assertEqual(z.dtype, dt) + self.assertEqual(z.item(), -1) + + x = mx.arange(10).astype(dt) - 5 + y = x % 5 + z = x % -5 + self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) + self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1]) + def test_comparisons(self): a = mx.array([0.0, 1.0, 5.0]) b = mx.array([-1.0, 2.0, 5.0]) @@ -1012,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(y.tolist(), [[3, 4]]) self.assertEqual(z.tolist(), [[5, 6]]) + with self.assertRaises(ValueError): + mx.split(a, 3, axis=2) + a = mx.arange(8) x, y, z = mx.split(a, [1, 5]) self.assertEqual(x.tolist(), [0]) @@ -1318,9 +1336,7 @@ class TestOps(mlx_tests.MLXTestCase): for d in dims: anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d) for n_bsx in range(d): - bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape( - [size] * n_bsx - ) + bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx) for _ in range(trial_mul * d): amlx = mx.array(anp) bmlx = mx.array(bnp) @@ -1371,6 +1387,11 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue((a[:-1] < 1e-9).all()) self.assertEqual(a[-1], 1) + # Sliced inputs + y = mx.random.uniform(shape=(8, 4)) + out = mx.softmax(y[:, 0:2], axis=-1) + self.assertAlmostEqual(out.sum().item(), 8.0) + def test_concatenate(self): a_npy = np.random.randn(32, 32, 32) b_npy = np.random.randn(32, 32, 32) @@ -1566,6 +1587,10 @@ class TestOps(mlx_tests.MLXTestCase): d_np = np.take(b_mx, np.arange(kth), axis=axis) self.assertTrue(np.all(d_np <= c_mx)) + @unittest.skipIf( + os.getenv("LOW_MEMORY", None) is not None, + "This test requires a lot of memory", + ) def test_large_binary(self): a = mx.ones([1000, 2147484], mx.int8) b = mx.ones([2147484], mx.int8) @@ -1677,6 +1702,8 @@ class TestOps(mlx_tests.MLXTestCase): def test_repeat(self): # Setup data for the tests data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) + # Test repeat 0 times + self.assertCmpNumpy([data, 0], mx.repeat, np.repeat) # Test repeat along axis 0 self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0) # Test repeat along axis 1 @@ -1856,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.array(np.diag(x, k=-1)) self.assertTrue(mx.array_equal(result, expected)) + def test_atleast_1d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_1d(mx.array(array)) + np_res = np.atleast_1d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_2d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_2d(mx.array(array)) + np_res = np.atleast_2d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_3d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_3d(mx.array(array)) + np_res = np.atleast_3d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index 59046184f..f978943de 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -1,50 +1,215 @@ # Copyright © 2023 Apple Inc. import inspect +import math import unittest +from functools import partial import mlx.core as mx +import mlx.nn as nn import mlx.optimizers as opt import mlx.utils import mlx_tests +from mlx.utils import tree_flatten, tree_map def get_all_optimizers(): classes = dict() for name, obj in inspect.getmembers(opt): - if inspect.isclass(obj): - if obj.__name__ not in ["OptimizerState", "Optimizer"]: - classes[name] = obj + if ( + inspect.isclass(obj) + and issubclass(obj, opt.Optimizer) + and obj != opt.Optimizer + ): + classes[name] = obj return classes +def tree_equal(fn, *args): + return all(v for _, v in tree_flatten(tree_map(fn, *args))) + + optimizers_dict = get_all_optimizers() class TestOptimizers(mlx_tests.MLXTestCase): + def test_optimizer_state(self): + optim = opt.SGD(0.1) + optim.state["hello"] = "world" + self.assertEqual(optim.state["hello"], "world") + + optim.state = {0: 1} + self.assertEqual(optim.state, {0: 1}) + def test_optimizers(self): params = { "first": [mx.zeros((10,)), mx.zeros((1,))], "second": mx.zeros((1,)), } - grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) + grads = tree_map(lambda x: mx.ones_like(x), params) for optim_class in optimizers_dict.values(): optim = optim_class(0.1) update = optim.apply_gradients(grads, params) mx.eval(update) - equal_shape = mlx.utils.tree_map( - lambda x, y: x.shape == y.shape, params, update - ) + equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update) all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) self.assertTrue(all_equal) + def test_types_conserved(self): + params = {"w": mx.ones((5, 5), mx.float16)} + grads = tree_map(lambda x: mx.ones_like(x), params) + for optim_class in optimizers_dict.values(): + optim = optim_class(0.1) + update = optim.apply_gradients(grads, params) + self.assertEqual(update["w"].dtype, mx.float16) + + def test_sgd(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + # Implicit init + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + optim.apply_gradients(grads, params) + self.assertTrue( + tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state) + ) + + def test_rmsprop(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.RMSprop(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + # Implicit init + alpha = 0.99 + optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha) + optim.apply_gradients(grads, params) + self.assertTrue( + tree_equal( + lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state + ) + ) + + def test_adagrad(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.Adagrad(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_adadelta(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.AdaDelta(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_adam(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]: + optim = optimizer(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_lion(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.Lion(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + def test_adafactor(self): x = mx.zeros((5, 5)) grad = mx.ones_like(x) optimizer = opt.Adafactor() for _ in range(2): - xp = optimizer.apply_single(grad, x, optimizer.state) + xp = optimizer.apply_gradients(grad, x) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) @@ -52,11 +217,129 @@ class TestOptimizers(mlx_tests.MLXTestCase): grad = mx.ones_like(x) optimizer = opt.Adafactor() for _ in range(2): - xp = optimizer.apply_single(grad, x, optimizer.state) + xp = optimizer.apply_gradients(grad, x) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) self.assertEqual(optimizer.state["step"], 2) + def test_compiled_optimizer(self): + model = nn.Linear(10, 10) + x = mx.random.uniform(shape=(2, 10)) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + + orig_params = model.parameters() + + def loss(model, x): + return model(x).sum() + + # Uncompiled version + def step(x): + _, grad = nn.value_and_grad(model, loss)(model, x) + optim.update(model, grad) + + step(x) + uncompiled_params = model.parameters() + + # Pure version + def loss(params, x): + model.update(params) + return model(x).sum() + + model.update(orig_params) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + + @mx.compile + def step(params, opt_state, x): + grad = mx.grad(loss)(params, x) + optim.state = opt_state + params = optim.apply_gradients(grad, params) + return params, optim.state + + optim.init(model.parameters()) + pure_params, _ = step(model.parameters(), optim.state, x) + self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"])) + self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"])) + + # Impure version + def loss(model, x): + return model(x).sum() + + model.update(orig_params) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + state = [model.state, optim.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x): + _, grad = nn.value_and_grad(model, loss)(model, x) + optim.update(model, grad) + + step(x) + impure_params = model.parameters() + self.assertTrue( + mx.allclose(impure_params["weight"], uncompiled_params["weight"]) + ) + self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"])) + + def test_update_lr_compiled(self): + params = {"w": mx.ones((5, 5))} + grads = tree_map(lambda x: mx.ones_like(x), params) + optim = opt.SGD(-1.0) + + @partial(mx.compile, inputs=optim.state) + def update(grads): + return optim.apply_gradients(grads, params) + + result = update(grads) + self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0))) + optim.learning_rate = -2.0 + result = update(grads) + self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) + + +class TestSchedulers(unittest.TestCase): + def test_decay_lr(self): + for optim_class in optimizers_dict.values(): + lr_schedule = opt.step_decay(1e-1, 0.9, 1000) + optimizer = optim_class(learning_rate=lr_schedule) + + params = {"w": mx.ones((5, 5))} + grads = tree_map(lambda x: mx.ones_like(x), params) + + for it in range(10): + expected_lr = 0.1 * (0.9**it) + self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7) + return optimizer.apply_gradients(grads, params) + + def test_step_decay(self): + lr_schedule = opt.step_decay(1e-1, 0.9, 1000) + lr = lr_schedule(2500) + expected_lr = 0.1 * (0.9**2) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_exponential_decay(self): + lr_schedule = opt.exponential_decay(1e-1, 0.99) + lr = lr_schedule(10) + expected_lr = 0.1 * (0.99**10) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_cosine_decay(self): + lr_schedule = opt.cosine_decay(0.1, 10) + lr = lr_schedule(4) + expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10)) + self.assertAlmostEqual(lr, expected_lr, delta=1e-7) + + def test_compile_with_schedule(self): + lr_schedule = opt.exponential_decay(1e-1, 0.9) + optimizer = opt.SGD(learning_rate=lr_schedule) + + @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state) + def update(): + optimizer.update({}, {}) + + for step in range(5): + update() + self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item()) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index b068aa6ee..fad2ba51c 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -165,6 +165,70 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_non_multiples(self): + w = mx.random.normal(shape=(33, 256)) + w_q, scales, biases = mx.quantize(w) + w_hat = mx.dequantize(w_q, scales, biases) + + # Test qmv + x = mx.random.normal(shape=(1, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm_t + x = mx.random.normal(shape=(10, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qvm + x = mx.random.normal(shape=(1, 33)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm + x = mx.random.normal(shape=(10, 33)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Smaller than 8 + w = mx.random.normal(shape=(3, 256)) + w_q, scales, biases = mx.quantize(w) + w_hat = mx.dequantize(w_q, scales, biases) + + # Test qmv + x = mx.random.normal(shape=(1, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm_t + x = mx.random.normal(shape=(10, 256)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qvm + x = mx.random.normal(shape=(1, 3)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + # Test qmm + x = mx.random.normal(shape=(10, 3)) + y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 0a06c3496..892db37df 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -80,6 +80,20 @@ class TestRandom(mlx_tests.MLXTestCase): a = mx.random.normal(dtype=t) self.assertEqual(a.dtype, t) + # Generate with a given mean and standard deviation + loc = 1.0 + scale = 2.0 + + a = mx.random.normal(shape=(3, 2), loc=loc, scale=scale, key=key) + b = scale * mx.random.normal(shape=(3, 2), key=key) + loc + self.assertTrue(mx.allclose(a, b)) + + a = mx.random.normal( + shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key + ) + b = scale * mx.random.normal(shape=(3, 2), dtype=mx.float16, key=key) + loc + self.assertTrue(mx.allclose(a, b)) + self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype) def test_randint(self): diff --git a/setup.py b/setup.py index 44cccf62d..961655419 100644 --- a/setup.py +++ b/setup.py @@ -152,7 +152,7 @@ if __name__ == "__main__": setup( name="mlx", - version=get_version("0.1.0"), + version=get_version("0.3.0"), author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index be9570f7e..34b69f233 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,6 +14,8 @@ if (MLX_BUILD_METAL) ) endif() +include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) + target_sources(tests PRIVATE allocator_tests.cpp array_tests.cpp @@ -37,4 +39,5 @@ target_sources(tests PRIVATE ) target_link_libraries(tests PRIVATE mlx doctest) +doctest_discover_tests(tests) add_test(NAME tests COMMAND tests) diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp index 080d53daa..62341c5c7 100644 --- a/tests/array_tests.cpp +++ b/tests/array_tests.cpp @@ -591,3 +591,21 @@ TEST_CASE("test array shared buffer") { eval(a + b); } + +TEST_CASE("test make empty array") { + auto a = array({}); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), float32); + + a = array({}, int32); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), int32); + + a = array({}, float32); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), float32); + + a = array({}, bool_); + CHECK_EQ(a.size(), 0); + CHECK_EQ(a.dtype(), bool_); +} diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 935a881a8..569ab0913 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -44,8 +44,8 @@ TEST_CASE("test compile with grad") { auto y = array(1.0f); auto grads_expected = grad_fun({x, y}); auto grads_compile = compile(grad_fun)({x, y}); - CHECK_EQ(grads_compile[0].item(), grads_expected[0].item()); - CHECK_EQ(grads_compile[1].item(), grads_expected[1].item()); + CHECK(allclose(grads_compile[0], grads_expected[0]).item()); + CHECK(allclose(grads_compile[1], grads_expected[1]).item()); } TEST_CASE("test compile inputs with primitive") { @@ -272,7 +272,7 @@ TEST_CASE("test compile unary fused") { CHECK_EQ(out.inputs()[0].id(), x.id()); auto expected_out = unary_fused_1({array(2.0)})[0]; - CHECK_EQ(out.item(), expected_out.item()); + CHECK(allclose(out, expected_out).item()); } { @@ -623,3 +623,83 @@ TEST_CASE("test transform compiled function") { CHECK(!outs[0].inputs()[0].has_primitive()); CHECK(!outs[0].inputs()[1].has_primitive()); } + +TEST_CASE("test fusion kernel reuse") { + auto cfun = compile(gelu_1); + auto x = array({2.0f, -2.0f}); + auto y = cfun({x})[0]; + auto p = std::dynamic_pointer_cast(y.primitive_ptr()); + eval(y); + + std::string lib_name = p->lib_name(); + CHECK(!lib_name.empty()); + + x = astype(reshape(arange(10), {2, 5}), float32); + auto z = cfun({x})[0]; + auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); + eval(z); + + std::string lib_name_z = pz->lib_name(); + CHECK(!lib_name_z.empty()); + + CHECK_EQ(lib_name, lib_name_z); +} + +auto add3(const std::vector& xs) { + return std::vector{xs[0] + xs[0] + xs[0]}; +} + +TEST_CASE("test fusion types") { + auto cfun = compile(add3); + auto x = array({2.0f, -2.0f}); + auto y = cfun({x})[0]; + auto p = std::dynamic_pointer_cast(y.primitive_ptr()); + eval(y); + + std::string lib_name = p->lib_name(); + CHECK(!lib_name.empty()); + + x = array({2, -2}, int32); + auto z = cfun({x})[0]; + auto pz = std::dynamic_pointer_cast(z.primitive_ptr()); + eval(z); + + std::string lib_name_z = pz->lib_name(); + CHECK(!lib_name_z.empty()); +} + +auto compile_shapeless_not_ok(const std::vector& inputs) { + auto x = reshape(inputs[0], {2, 2}); + return std::vector{x}; +} + +auto compile_shapeless_ok(const std::vector& inputs) { + auto x = inputs[0] + array({2}); + return std::vector{x}; +} + +TEST_CASE("test shapeless compile") { + { + auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true); + CHECK_THROWS(cfun({array({1, 2, 3, 4})})); + } + + { + auto cfun = compile(compile_shapeless_ok, /* shapeless */ true); + auto out = cfun({array({1, 2})})[0]; + auto out2 = cfun({array({1, 2, 3, 4})})[0]; + + // Not making a new constant array since no recompile, + // hence the ids should be the same + CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id()); + CHECK(array_equal(out2, array({3, 4, 5, 6})).item()); + + // Recompile since type changes + out2 = cfun({array({1.0, 2.0})})[0]; + CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); + + // Recompile since ndim changes + out2 = cfun({array({1.0, 2.0}, {1, 2})})[0]; + CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id()); + } +} diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 51d1659f3..3a7556b57 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -19,8 +19,14 @@ TEST_CASE("test save_safetensors") { auto map = std::unordered_map(); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test2", ones({2, 2})}); - save_safetensors(file_path, map); - auto dict = load_safetensors(file_path); + auto _metadata = std::unordered_map(); + _metadata.insert({"test", "test"}); + _metadata.insert({"test2", "test2"}); + save_safetensors(file_path, map, _metadata); + auto [dict, metadata] = load_safetensors(file_path); + + CHECK_EQ(metadata, _metadata); + CHECK_EQ(dict.size(), 2); CHECK_EQ(dict.count("test"), 1); CHECK_EQ(dict.count("test2"), 1); @@ -55,7 +61,7 @@ TEST_CASE("test gguf") { } // Test saving and loading string metadata - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_str", "my string"}); save_gguf(file_path, original_weights, original_metadata); @@ -97,7 +103,7 @@ TEST_CASE("test gguf metadata") { // Scalar array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array(1.0)}); save_gguf(file_path, original_weights, original_metadata); @@ -111,7 +117,7 @@ TEST_CASE("test gguf metadata") { // 1D Array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; auto arr = array({1.0, 2.0}); original_metadata.insert({"test_arr", arr}); save_gguf(file_path, original_weights, original_metadata); @@ -138,21 +144,21 @@ TEST_CASE("test gguf metadata") { // > 1D array throws { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array({1.0}, {1, 1})}); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); } // empty array throws { - std::unordered_map original_metadata; + std::unordered_map original_metadata; original_metadata.insert({"test_arr", array({})}); CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); } // vector of string { - std::unordered_map original_metadata; + std::unordered_map original_metadata; std::vector data = {"data1", "data2", "data1234"}; original_metadata.insert({"meta", data}); save_gguf(file_path, original_weights, original_metadata); @@ -169,7 +175,7 @@ TEST_CASE("test gguf metadata") { // vector of string, string, scalar, and array { - std::unordered_map original_metadata; + std::unordered_map original_metadata; std::vector data = {"data1", "data2", "data1234"}; original_metadata.insert({"meta1", data}); original_metadata.insert({"meta2", array(2.5)}); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index e52c1294f..ba4ab552f 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1002,7 +1002,7 @@ TEST_CASE("test arithmetic unary ops") { CHECK_EQ(exp(x).item(), 1.0); x = array(2.0); - CHECK_EQ(exp(x).item(), std::exp(2.0f)); + CHECK_EQ(exp(x).item(), doctest::Approx(std::exp(2.0f))); CHECK(array_equal(exp(array({})), array({})).item()); @@ -1012,7 +1012,7 @@ TEST_CASE("test arithmetic unary ops") { // Integer input type x = array(2); CHECK_EQ(x.dtype(), int32); - CHECK_EQ(exp(x).item(), std::exp(2.0f)); + CHECK_EQ(exp(x).item(), doctest::Approx(std::exp(2.0f))); // Input is irregularly strided x = broadcast_to(array(1.0f), {2, 2, 2}); @@ -1020,7 +1020,7 @@ TEST_CASE("test arithmetic unary ops") { x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); - CHECK(array_equal(exp(x), expected).item()); + CHECK(allclose(exp(x), expected).item()); } // Test sine @@ -2716,3 +2716,54 @@ TEST_CASE("test diag") { out = diag(x, -1); CHECK(array_equal(out, array({3, 7}, {2})).item()); } + +TEST_CASE("test atleast_1d") { + auto x = array(1); + auto out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector{1}); + + x = array({1, 2, 3}, {3}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector{3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{3, 1}); +} + +TEST_CASE("test atleast_2d") { + auto x = array(1); + auto out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{1, 3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{3, 1}); +} + +TEST_CASE("test atleast_3d") { + auto x = array(1); + auto out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{1, 1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{1, 3, 1}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{3, 1, 1}); +} \ No newline at end of file