mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 16:51:24 +08:00
Compare commits
65 Commits
v0.24.0
...
trellis-qu
Author | SHA1 | Date | |
---|---|---|---|
![]() |
998404ada4 | ||
![]() |
e3d275bc49 | ||
![]() |
d7acf59fd0 | ||
![]() |
e9e268336b | ||
![]() |
7275ac7523 | ||
![]() |
c4189a38e4 | ||
![]() |
68d1b3256b | ||
![]() |
9c6953bda7 | ||
![]() |
ef7ece9851 | ||
![]() |
ddaa4b7dcb | ||
![]() |
dfae2c6989 | ||
![]() |
515f104926 | ||
![]() |
9ecefd56db | ||
![]() |
e5d35aa187 | ||
![]() |
00794c42bc | ||
![]() |
08a1bf3f10 | ||
![]() |
60c4154346 | ||
![]() |
f2c85308c1 | ||
![]() |
1a28b69ee2 | ||
![]() |
ba09f01ce8 | ||
![]() |
6cf48872b7 | ||
![]() |
7b3b8fa000 | ||
![]() |
ec5e2aae61 | ||
![]() |
86389bf970 | ||
![]() |
3290bfa690 | ||
![]() |
8777fd104f | ||
![]() |
c41f7565ed | ||
![]() |
9ba81e3da4 | ||
![]() |
c23888acd7 | ||
![]() |
f98ce25ab9 | ||
![]() |
de5f38fd48 | ||
![]() |
ec2854b13a | ||
![]() |
90823d2938 | ||
![]() |
5f5770e3a2 | ||
![]() |
28f39e9038 | ||
![]() |
b2d2b37888 | ||
![]() |
fe597e141c | ||
![]() |
72ca1539e0 | ||
![]() |
13b26775f1 | ||
![]() |
05d7118561 | ||
![]() |
98b901ad66 | ||
![]() |
5580b47291 | ||
![]() |
bc62932984 | ||
![]() |
a6b5d6e759 | ||
![]() |
a8931306e1 | ||
![]() |
fecdb8717e | ||
![]() |
916fd273ea | ||
![]() |
0da8506552 | ||
![]() |
eda7a7b43e | ||
![]() |
022eabb734 | ||
![]() |
aba899cef8 | ||
![]() |
6a40e1c176 | ||
![]() |
9307b2ab8b | ||
![]() |
522d8d3917 | ||
![]() |
a84cc0123f | ||
![]() |
f018e248cd | ||
![]() |
cfd7237a80 | ||
![]() |
4eef8102c9 | ||
![]() |
69e4dd506b | ||
![]() |
25814a9458 | ||
![]() |
2a980a76ce | ||
![]() |
d343782c8b | ||
![]() |
4e1994e9d7 | ||
![]() |
65a38c452b | ||
![]() |
7b7e2352cd |
@@ -24,8 +24,8 @@ jobs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.medium.gen1
|
||||
xcode: "16.2.0"
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -89,15 +89,14 @@ jobs:
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
|
||||
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
|
||||
CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
- run:
|
||||
@@ -110,6 +109,8 @@ jobs:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
@@ -124,10 +125,15 @@ jobs:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.2.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -149,7 +155,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
@@ -213,13 +219,18 @@ jobs:
|
||||
default: "3.9"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.2.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
resource_class: m2pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -240,7 +251,7 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEV_RELEASE=1 \
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
@@ -335,7 +346,7 @@ workflows:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- build_documentation
|
||||
|
||||
@@ -355,8 +366,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
@@ -379,7 +452,7 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -392,7 +465,54 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -403,8 +523,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
|
@@ -212,24 +212,6 @@ else()
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
find_package(MPI)
|
||||
if(MPI_FOUND)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "mpirun --version"
|
||||
OUTPUT_VARIABLE MPI_VERSION
|
||||
ERROR_QUIET)
|
||||
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||
elseif(MPI_VERSION STREQUAL "")
|
||||
set(MPI_FOUND FALSE)
|
||||
message(
|
||||
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
||||
else()
|
||||
set(MPI_FOUND FALSE)
|
||||
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Downloading json")
|
||||
FetchContent_Declare(
|
||||
json
|
||||
|
@@ -5,26 +5,26 @@ possible.
|
||||
|
||||
## Pull Requests
|
||||
|
||||
1. Fork and submit pull requests to the repo.
|
||||
1. Fork and submit pull requests to the repo.
|
||||
2. If you've added code that should be tested, add tests.
|
||||
3. If a change is likely to impact efficiency, run some of the benchmarks before
|
||||
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
|
||||
4. If you've changed APIs, update the documentation.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
|
||||
This should install hooks for running `black` and `clang-format` to ensure
|
||||
consistent style for C++ and python code.
|
||||
|
||||
|
||||
You can also run the formatters manually as follows:
|
||||
|
||||
```
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```
|
||||
black file.py
|
||||
```
|
||||
|
||||
|
||||
```shell
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```shell
|
||||
black file.py
|
||||
```
|
||||
|
||||
or run `pre-commit run --all-files` to check all files in the repo.
|
||||
|
||||
## Issues
|
||||
|
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
|
||||
CREATE_SUBDIRS = NO
|
||||
FULL_PATH_NAMES = YES
|
||||
RECURSIVE = YES
|
||||
GENERATE_HTML = YES
|
||||
GENERATE_HTML = NO
|
||||
GENERATE_LATEX = NO
|
||||
GENERATE_XML = YES
|
||||
XML_PROGRAMLISTING = YES
|
||||
|
@@ -93,9 +93,9 @@ Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
||||
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
@@ -128,7 +128,7 @@ more concrete:
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
@@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
// Allocate the output with `malloc_or_wait` which synchronously allocates
|
||||
// memory, potentially waiting if the system is under memory pressure
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
@@ -393,7 +391,7 @@ below.
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
@@ -471,7 +469,7 @@ one we just defined:
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
// that are scheduled on the same stream as the primitive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
@@ -483,7 +481,7 @@ one we just defined:
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// If argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
@@ -737,7 +735,7 @@ Let's look at a simple script and its results:
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c is correct: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
@@ -745,7 +743,7 @@ Output:
|
||||
|
||||
c shape: [3, 4]
|
||||
c dtype: float32
|
||||
c correctness: True
|
||||
c is correct: True
|
||||
|
||||
Results
|
||||
^^^^^^^
|
||||
|
@@ -70,6 +70,7 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/memory_management
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/distributed
|
||||
|
@@ -38,6 +38,7 @@ Array
|
||||
array.log10
|
||||
array.log1p
|
||||
array.log2
|
||||
array.logcumsumexp
|
||||
array.logsumexp
|
||||
array.max
|
||||
array.mean
|
||||
|
@@ -20,5 +20,6 @@ Linear Algebra
|
||||
eigh
|
||||
lu
|
||||
lu_factor
|
||||
pinv
|
||||
solve
|
||||
solve_triangular
|
||||
|
16
docs/src/python/memory_management.rst
Normal file
16
docs/src/python/memory_management.rst
Normal file
@@ -0,0 +1,16 @@
|
||||
Memory Management
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
@@ -8,13 +8,5 @@ Metal
|
||||
|
||||
is_available
|
||||
device_info
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@@ -36,10 +36,12 @@ Operations
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
broadcast_arrays
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
contiguous
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
@@ -101,6 +103,7 @@ Operations
|
||||
log10
|
||||
log1p
|
||||
logaddexp
|
||||
logcumsumexp
|
||||
logical_not
|
||||
logical_and
|
||||
logical_or
|
||||
|
@@ -18,3 +18,4 @@ Common Optimizers
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
|
@@ -9,6 +9,7 @@ Transforms
|
||||
:toctree: _autosummary
|
||||
|
||||
eval
|
||||
async_eval
|
||||
compile
|
||||
custom_function
|
||||
disable_compile
|
||||
|
@@ -72,9 +72,7 @@ void axpby_impl(
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
// Allocate the output with `malloc_or_wait` which synchronously allocates
|
||||
// memory, potentially waiting if the system is under memory pressure
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
@@ -160,12 +158,12 @@ void Axpby::eval_gpu(
|
||||
// Allocate output memory with strides based on specialization
|
||||
if (contiguous_kernel) {
|
||||
out.set_data(
|
||||
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
mx::allocator::malloc(x.data_size() * out.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
} else {
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
|
@@ -4,12 +4,11 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
||||
auto buffer = allocator().malloc(size);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
}
|
||||
return Buffer{ptr};
|
||||
}
|
||||
|
||||
void CommonAllocator::free(Buffer buffer) {
|
||||
std::free(buffer.ptr());
|
||||
}
|
||||
|
||||
size_t CommonAllocator::size(Buffer buffer) const {
|
||||
if (buffer.ptr() == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
Buffer malloc_or_wait(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
|
||||
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
|
||||
scheduler::wait_for_one();
|
||||
buffer = allocator().malloc(size);
|
||||
}
|
||||
|
||||
// Try swapping if needed
|
||||
if (size && !buffer.ptr()) {
|
||||
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
||||
}
|
||||
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
// Wait for running tasks to finish and free up memory
|
||||
// if allocation fails
|
||||
Buffer malloc_or_wait(size_t size);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
|
||||
@@ -53,16 +49,4 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
private:
|
||||
CommonAllocator() = default;
|
||||
friend Allocator& allocator();
|
||||
};
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b_donatable) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||
allocator::malloc(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(a);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@@ -103,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
|
@@ -188,7 +188,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@@ -211,7 +211,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
return true;
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
namespace mlx::core {
|
||||
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto read_task = [out_ptr = out.data<char>(),
|
||||
size = out.size(),
|
||||
itemsize = out.itemsize(),
|
||||
|
@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
|
||||
switch (topt) {
|
||||
case TernaryOpType::ScalarScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
||||
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||
break;
|
||||
case TernaryOpType::VectorVectorVector:
|
||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
|
||||
allocator::malloc(out.itemsize() * b.data_size()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@@ -58,6 +58,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
@@ -73,8 +74,8 @@ target_sources(
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
|
||||
endif()
|
||||
|
||||
if(IOS)
|
||||
|
@@ -68,7 +68,7 @@ void arg_reduce_dispatch(
|
||||
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
@@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
|
||||
temps.push_back(gemm_out);
|
||||
}
|
||||
|
||||
@@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
|
||||
temps.push_back(gemm_out);
|
||||
}
|
||||
|
||||
@@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
|
||||
temps.push_back(gemm_out);
|
||||
}
|
||||
|
||||
@@ -1327,7 +1327,7 @@ void conv_3D_cpu(
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& wt = inputs[1];
|
||||
|
@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
return in;
|
||||
} else {
|
||||
@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, min and max are supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +65,7 @@ void AllGather::eval_cpu(
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::all_gather(group(), in, outputs[0], stream());
|
||||
if (copied) {
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
@@ -87,7 +94,7 @@ void Recv::eval_cpu(
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
|
@@ -55,9 +55,8 @@ void eigh_impl(
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto iwork_buf =
|
||||
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
@@ -98,7 +97,7 @@ void Eigh::eval_cpu(
|
||||
? outputs[1]
|
||||
: array(a.shape(), a.dtype(), nullptr, {});
|
||||
|
||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
copy(
|
||||
a,
|
||||
|
@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
s *= out.itemsize();
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
std::vector<size_t> shape;
|
||||
if (out.dtype() == float32) {
|
||||
|
@@ -1,27 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const bfloat16_t*,
|
||||
const bfloat16_t*,
|
||||
bfloat16_t*,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float,
|
||||
size_t,
|
||||
const Shape&,
|
||||
const Strides&,
|
||||
const Shape&,
|
||||
const Strides&) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,27 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const float16_t*,
|
||||
const float16_t*,
|
||||
float16_t*,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float,
|
||||
size_t,
|
||||
const Shape&,
|
||||
const Strides&,
|
||||
const Shape&,
|
||||
const Strides&) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
45
mlx/backend/cpu/gemms/simd_bf16.cpp
Normal file
45
mlx/backend/cpu/gemms/simd_bf16.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const bfloat16_t* a,
|
||||
const bfloat16_t* b,
|
||||
bfloat16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
simd_gemm<bfloat16_t, float>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
out + M * N * i,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const float16_t* a,
|
||||
const float16_t* b,
|
||||
float16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
simd_gemm<float16_t, float>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
out + M * N * i,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
139
mlx/backend/cpu/gemms/simd_gemm.h
Normal file
139
mlx/backend/cpu/gemms/simd_gemm.h
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline int ceildiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <int block_size, typename T, typename AccT>
|
||||
void load_block(
|
||||
const T* in,
|
||||
AccT* out,
|
||||
int M,
|
||||
int N,
|
||||
int i,
|
||||
int j,
|
||||
bool transpose) {
|
||||
if (transpose) {
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
out[jj * block_size + ii] =
|
||||
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
out[ii * block_size + jj] =
|
||||
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void simd_gemm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
T* c,
|
||||
bool a_trans,
|
||||
bool b_trans,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
float alpha,
|
||||
float beta) {
|
||||
constexpr int block_size = 16;
|
||||
constexpr int simd_size = simd::max_size<AccT>;
|
||||
static_assert(
|
||||
(block_size % simd_size) == 0,
|
||||
"Block size must be divisible by SIMD size");
|
||||
|
||||
int last_k_block_size = K - block_size * (K / block_size);
|
||||
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
|
||||
for (int i = 0; i < ceildiv(M, block_size); i++) {
|
||||
for (int j = 0; j < ceildiv(N, block_size); j++) {
|
||||
AccT c_block[block_size * block_size] = {0.0};
|
||||
AccT a_block[block_size * block_size];
|
||||
AccT b_block[block_size * block_size];
|
||||
|
||||
int k = 0;
|
||||
for (; k < K / block_size; k++) {
|
||||
// Load a and b blocks
|
||||
if (a_trans) {
|
||||
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||
} else {
|
||||
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||
}
|
||||
if (b_trans) {
|
||||
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||
} else {
|
||||
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||
}
|
||||
|
||||
// Multiply and accumulate
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
for (int kk = 0; kk < block_size; kk += simd_size) {
|
||||
auto av =
|
||||
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||
auto bv =
|
||||
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_k_block_size) {
|
||||
// Load a and b blocks
|
||||
if (a_trans) {
|
||||
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||
} else {
|
||||
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||
}
|
||||
if (b_trans) {
|
||||
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||
} else {
|
||||
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||
}
|
||||
|
||||
// Multiply and accumulate
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
int kk = 0;
|
||||
for (; kk < last_k_simd_block; kk += simd_size) {
|
||||
auto av =
|
||||
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||
auto bv =
|
||||
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||
}
|
||||
for (; kk < last_k_block_size; ++kk) {
|
||||
c_block[ii * block_size + jj] +=
|
||||
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
|
||||
if (beta != 0) {
|
||||
c[c_idx] = static_cast<T>(
|
||||
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
|
||||
} else {
|
||||
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -197,7 +197,7 @@ void dispatch_gather(
|
||||
}
|
||||
|
||||
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds;
|
||||
@@ -354,7 +354,7 @@ void dispatch_gather_axis(
|
||||
}
|
||||
|
||||
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
auto& inds = inputs[1];
|
||||
|
@@ -11,7 +11,7 @@ namespace mlx::core {
|
||||
template <typename T>
|
||||
void general_inv(T* inv, int N) {
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
|
||||
// Compute LU factorization.
|
||||
getrf<T>(
|
||||
/* m = */ &N,
|
||||
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
getri<T>(
|
||||
|
140
mlx/backend/cpu/logsumexp.cpp
Normal file
140
mlx/backend/cpu/logsumexp.cpp
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/types/limits.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void logsumexp(const array& in, array& out, Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
|
||||
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
|
||||
const T* current_in_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
}
|
||||
// Normalize
|
||||
*out_ptr = std::isinf(maximum)
|
||||
? static_cast<T>(maximum)
|
||||
: static_cast<T>(std::log(normalizer) + maximum);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(s);
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
switch (in.dtype()) {
|
||||
case float32:
|
||||
logsumexp<float, float>(in, out, stream());
|
||||
break;
|
||||
case float16:
|
||||
logsumexp<float16_t, float>(in, out, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
logsumexp<bfloat16_t, float>(in, out, stream());
|
||||
break;
|
||||
case float64:
|
||||
logsumexp<double, double>(in, out, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[logsumexp] only supports floating point types");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -30,8 +30,7 @@ void luf_impl(
|
||||
auto strides = lu.strides();
|
||||
strides[ndim - 1] = M;
|
||||
strides[ndim - 2] = 1;
|
||||
lu.set_data(
|
||||
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
copy_inplace(
|
||||
a,
|
||||
lu,
|
||||
@@ -44,8 +43,8 @@ void luf_impl(
|
||||
stream);
|
||||
|
||||
auto a_ptr = lu.data<T>();
|
||||
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
||||
pivots.set_data(allocator::malloc(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
|
||||
auto pivots_ptr = pivots.data<uint32_t>();
|
||||
auto row_indices_ptr = row_indices.data<uint32_t>();
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
@@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
@@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
@@ -115,7 +115,7 @@ void matmul_general(
|
||||
}
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
|
@@ -21,7 +21,7 @@ namespace mlx::core {
|
||||
void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
throw std::runtime_error("Bool type unsupported for arange.");
|
||||
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General, stream());
|
||||
@@ -276,7 +278,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
auto cptr = out.data<char>();
|
||||
@@ -335,7 +337,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto [in_offset, donated] =
|
||||
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
@@ -450,7 +452,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
auto tmp = array(
|
||||
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
if (in.dtype() == bool_) {
|
||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||
in_tmp.copy_shared_buffer(in);
|
||||
|
@@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
auto strides = in.strides();
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(
|
||||
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
|
||||
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
q.set_data(allocator::malloc_or_wait(q.nbytes()));
|
||||
r.set_data(allocator::malloc_or_wait(r.nbytes()));
|
||||
q.set_data(allocator::malloc(q.nbytes()));
|
||||
r.set_data(allocator::malloc(r.nbytes()));
|
||||
|
||||
auto in_ptr = in.data<T>();
|
||||
auto r_ptr = r.data<T>();
|
||||
@@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
encoder.set_output_array(r);
|
||||
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
|
||||
int num_reflectors = std::min(M, N);
|
||||
auto tau =
|
||||
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
||||
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
|
||||
|
||||
T optimal_work;
|
||||
int lwork = -1;
|
||||
@@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
|
||||
// Update workspace size
|
||||
lwork = optimal_work;
|
||||
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
auto work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
@@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = optimal_work;
|
||||
work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
|
@@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
@@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
@@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(
|
||||
|
||||
auto [w, copied] = ensure_row_contiguous(inputs[0]);
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
if (copied) {
|
||||
encoder.add_temporary(w);
|
||||
|
@@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
@@ -226,6 +227,16 @@ void scan_dispatch(
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::LogAddExp: {
|
||||
auto op = [](U a, T b) {
|
||||
return detail::LogAddExp{}(a, static_cast<U>(b));
|
||||
};
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::min();
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,7 +255,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in = arr_copy;
|
||||
encoder.add_temporary(arr_copy);
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr int max_size<float16_t> = N;
|
||||
inline constexpr int max_size<float16_t> = N;
|
||||
|
||||
#define SIMD_FP16_DEFAULT_UNARY(op) \
|
||||
template <> \
|
||||
|
@@ -83,25 +83,25 @@ struct Simd {
|
||||
// Values chosen based on benchmarks on M3 Max
|
||||
// TODO: consider choosing these more optimally
|
||||
template <>
|
||||
static constexpr int max_size<int8_t> = 16;
|
||||
inline constexpr int max_size<int8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int16_t> = 16;
|
||||
inline constexpr int max_size<int16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int> = 8;
|
||||
inline constexpr int max_size<int> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<int64_t> = 4;
|
||||
inline constexpr int max_size<int64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<uint8_t> = 16;
|
||||
inline constexpr int max_size<uint8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint16_t> = 16;
|
||||
inline constexpr int max_size<uint16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint32_t> = 8;
|
||||
inline constexpr int max_size<uint32_t> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<uint64_t> = 4;
|
||||
inline constexpr int max_size<uint64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<float> = 8;
|
||||
inline constexpr int max_size<float> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<double> = 4;
|
||||
inline constexpr int max_size<double> = 4;
|
||||
|
||||
#define SIMD_DEFAULT_UNARY(name, op) \
|
||||
template <typename T, int N> \
|
||||
|
@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
|
||||
DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log2, std::log2)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto out = std::log(in.value);
|
||||
auto scale = decltype(out.real())(M_LN2);
|
||||
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
|
||||
} else {
|
||||
return Simd<T, 1>{std::log2(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> operator~(Simd<T, 1> in) {
|
||||
return ~in.value;
|
||||
|
@@ -119,17 +119,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto set_output = [s = stream(), &out](const array& x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = set_output(inputs[0]);
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::runtime_error(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float, float>(in, out, stream());
|
||||
break;
|
||||
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case float64:
|
||||
softmax<double, double>(in, out, stream());
|
||||
break;
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
"[Softmax] Not yet implemented for complex64");
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[softmax] Only defined for floating point types.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
@@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
|
@@ -50,9 +50,9 @@ void svd_impl(
|
||||
array& s = outputs[1];
|
||||
array& vt = outputs[2];
|
||||
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
u.set_data(allocator::malloc(u.nbytes()));
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
vt.set_data(allocator::malloc(vt.nbytes()));
|
||||
|
||||
encoder.set_output_array(u);
|
||||
encoder.set_output_array(s);
|
||||
@@ -64,7 +64,7 @@ void svd_impl(
|
||||
} else {
|
||||
array& s = outputs[0];
|
||||
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
|
||||
encoder.set_output_array(s);
|
||||
|
||||
@@ -91,7 +91,7 @@ void svd_impl(
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
@@ -132,7 +132,7 @@ void svd_impl(
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
|
@@ -1,5 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Required for using M_LN2 in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/unary.h"
|
||||
|
@@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) {
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
allocator::malloc(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -86,13 +86,14 @@ struct Sign {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
auto z = Simd<T, N>{0};
|
||||
auto o = Simd<T, N>{1};
|
||||
auto m = Simd<T, N>{-1};
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x != z;
|
||||
return simd::select(x == z, z, o);
|
||||
} else if constexpr (std::is_same_v<T, complex64_t>) {
|
||||
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
|
||||
} else {
|
||||
return simd::select(
|
||||
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
|
||||
return simd::select(x < z, m, simd::select(x > z, o, z));
|
||||
}
|
||||
}
|
||||
SINGLE()
|
||||
|
@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
|
||||
make_jit_source(logsumexp)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
@@ -95,6 +96,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
@@ -32,8 +33,11 @@ namespace metal {
|
||||
|
||||
namespace {
|
||||
|
||||
BufferCache::BufferCache(MTL::Device* device)
|
||||
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
|
||||
BufferCache::BufferCache(ResidencySet& residency_set)
|
||||
: head_(nullptr),
|
||||
tail_(nullptr),
|
||||
pool_size_(0),
|
||||
residency_set_(residency_set) {}
|
||||
|
||||
BufferCache::~BufferCache() {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
@@ -44,6 +48,9 @@ int BufferCache::clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf) {
|
||||
if (!holder->buf->heap()) {
|
||||
residency_set_.erase(holder->buf);
|
||||
}
|
||||
holder->buf->release();
|
||||
n_release++;
|
||||
}
|
||||
@@ -101,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
if (tail_->buf) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
if (!tail_->buf->heap()) {
|
||||
residency_set_.erase(tail_->buf);
|
||||
}
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
n_release++;
|
||||
@@ -155,7 +165,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
buffer_cache_(residency_set_) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||
auto max_rec_size =
|
||||
@@ -192,16 +202,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, block_limit_);
|
||||
relaxed_ = relaxed;
|
||||
gc_limit_ = std::min(
|
||||
block_limit_,
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::get_memory_limit() {
|
||||
return block_limit_;
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(limit, wired_limit_);
|
||||
@@ -209,7 +222,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
Buffer MetalAllocator::malloc(size_t size) {
|
||||
// Metal doesn't like empty buffers
|
||||
if (size == 0) {
|
||||
return Buffer{nullptr};
|
||||
@@ -236,11 +249,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
if (!buf) {
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
|
||||
// If there is too much memory pressure, fail (likely causes a wait).
|
||||
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
@@ -264,9 +272,13 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
if (!buf) {
|
||||
buf = device_->newBuffer(size, resource_options);
|
||||
}
|
||||
if (!buf) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
lk.lock();
|
||||
if (buf) {
|
||||
num_resources_++;
|
||||
num_resources_++;
|
||||
if (!buf->heap()) {
|
||||
residency_set_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,10 +292,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
if (!buf->heap()) {
|
||||
residency_set_.insert(buf);
|
||||
}
|
||||
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
@@ -299,14 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
active_memory_ -= buf->length();
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
num_resources_--;
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
lk.unlock();
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buf->release();
|
||||
@@ -325,37 +333,40 @@ MetalAllocator& allocator() {
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return allocator().set_cache_limit(limit);
|
||||
return metal::allocator().set_cache_limit(limit);
|
||||
}
|
||||
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
|
||||
return allocator().set_memory_limit(limit, relaxed);
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return metal::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return metal::allocator().get_memory_limit();
|
||||
}
|
||||
size_t set_wired_limit(size_t limit) {
|
||||
if (limit >
|
||||
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
|
||||
if (limit > std::get<size_t>(metal::device_info().at(
|
||||
"max_recommended_working_set_size"))) {
|
||||
throw std::invalid_argument(
|
||||
"[metal::set_wired_limit] Setting a wired limit larger than "
|
||||
"the maximum working set size is not allowed.");
|
||||
}
|
||||
return allocator().set_wired_limit(limit);
|
||||
return metal::allocator().set_wired_limit(limit);
|
||||
}
|
||||
size_t get_active_memory() {
|
||||
return allocator().get_active_memory();
|
||||
return metal::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return allocator().get_peak_memory();
|
||||
return metal::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
allocator().reset_peak_memory();
|
||||
metal::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return allocator().get_cache_memory();
|
||||
return metal::allocator().get_cache_memory();
|
||||
}
|
||||
void clear_cache() {
|
||||
return allocator().clear_cache();
|
||||
return metal::allocator().clear_cache();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -18,7 +18,7 @@ namespace {
|
||||
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(MTL::Device* device);
|
||||
BufferCache(ResidencySet& residency_set);
|
||||
~BufferCache();
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
@@ -42,13 +42,11 @@ class BufferCache {
|
||||
void add_at_head(BufferHolder* to_add);
|
||||
void remove_from_list(BufferHolder* to_remove);
|
||||
|
||||
MTL::Device* device_;
|
||||
MTL::Heap* heap_{nullptr};
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_;
|
||||
BufferHolder* tail_;
|
||||
size_t pool_size_;
|
||||
ResidencySet& residency_set_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -56,7 +54,7 @@ class BufferCache {
|
||||
class MetalAllocator : public allocator::Allocator {
|
||||
/** Allocator for Metal GPUs. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
size_t get_active_memory() {
|
||||
@@ -73,7 +71,8 @@ class MetalAllocator : public allocator::Allocator {
|
||||
return buffer_cache_.cache_size();
|
||||
};
|
||||
size_t set_cache_limit(size_t limit);
|
||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_memory_limit();
|
||||
size_t set_wired_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
|
@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
|
||||
Shape unfolded_shape{implicit_M, implicit_K};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
// Prepare unfolding array
|
||||
Shape unfolded_shape{implicit_M, implicit_K * groups};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do filter transform
|
||||
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
||||
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
|
||||
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
|
||||
filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
|
||||
copies_w.push_back(filt_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do input transform
|
||||
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
||||
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
|
||||
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
|
||||
inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
|
||||
copies_w.push_back(inp_wg);
|
||||
{
|
||||
int bc = 32;
|
||||
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
|
||||
// Do batched gemm
|
||||
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
||||
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
|
||||
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
|
||||
out_wg.set_data(allocator::malloc(out_wg.nbytes()));
|
||||
copies_w.push_back(out_wg);
|
||||
{
|
||||
std::vector<array> empty_copies;
|
||||
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void depthwise_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
std::ostringstream kname;
|
||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const int N = conv_params.N;
|
||||
const int ker_h = conv_params.wS[0];
|
||||
const int ker_w = conv_params.wS[1];
|
||||
const int str_h = conv_params.str[0];
|
||||
const int str_w = conv_params.str[1];
|
||||
const int tc = 8;
|
||||
const int tw = 8;
|
||||
const int th = 4;
|
||||
const bool do_flip = conv_params.flip;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&ker_h, MTL::DataType::DataTypeInt, 00},
|
||||
{&ker_w, MTL::DataType::DataTypeInt, 01},
|
||||
{&str_h, MTL::DataType::DataTypeInt, 10},
|
||||
{&str_w, MTL::DataType::DataTypeInt, 11},
|
||||
{&th, MTL::DataType::DataTypeInt, 100},
|
||||
{&tw, MTL::DataType::DataTypeInt, 101},
|
||||
{&do_flip, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_ker_h_" << ker_h
|
||||
<< "_ker_w_" << ker_w
|
||||
<< "_str_h_" << str_h
|
||||
<< "_str_w_" << str_w
|
||||
<< "_tgp_h_" << th
|
||||
<< "_tgp_w_" << tw
|
||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tc, tw, th);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -754,11 +813,20 @@ void conv_2D_gpu(
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
if (groups > 1) {
|
||||
if (is_idil_one && groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
||||
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
||||
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
||||
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
||||
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
||||
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
||||
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
@@ -855,7 +923,7 @@ void conv_3D_gpu(
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
|
@@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
||||
|
@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
fill_gpu(copies.back(), out, s);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
MTL::Library* try_load_bundle(
|
||||
MTL::Device* device,
|
||||
NS::URL* url,
|
||||
const std::string& lib_name) {
|
||||
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
|
||||
SWIFTPM_BUNDLE + ".bundle";
|
||||
auto bundle = NS::Bundle::alloc()->init(
|
||||
@@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
if (bundle != nullptr) {
|
||||
std::string resource_path =
|
||||
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
|
||||
"default.metallib";
|
||||
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
|
||||
lib_name + ".metallib" auto [lib, error] =
|
||||
load_library_from_path(device, resource_path.c_str());
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
@@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
|
||||
}
|
||||
#endif
|
||||
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
std::string lib_path = get_colocated_mtllib_path(lib_name);
|
||||
if (lib_path.size() != 0) {
|
||||
return load_library_from_path(device, lib_path.c_str());
|
||||
}
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name) {
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
MTL::Library* library =
|
||||
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
MTL::Library* load_default_library(MTL::Device* device) {
|
||||
NS::Error *error1, *error2, *error3;
|
||||
MTL::Library* lib;
|
||||
// First try the colocated mlx.metallib
|
||||
std::tie(lib, error1) = load_colocated_library(device, "mlx");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Then try default.metallib in a SwiftPM bundle if we have one
|
||||
std::tie(lib, error2) = load_swiftpm_library(device, "default");
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
|
||||
// Finally try default_mtllib_path
|
||||
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the default metallib. ";
|
||||
if (error1 != nullptr) {
|
||||
msg << error1->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error2 != nullptr) {
|
||||
msg << error2->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
if (error3 != nullptr) {
|
||||
msg << error3->localizedDescription()->utf8String() << " ";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
|
||||
MTL::Library* load_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name = "mlx",
|
||||
const char* lib_path = default_mtllib_path) {
|
||||
// Firstly, search for the metallib in the same path as this binary
|
||||
std::string first_path = get_colocated_mtllib_path(lib_name);
|
||||
if (first_path.size() != 0) {
|
||||
auto [lib, error] = load_library_from_path(device, first_path.c_str());
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
// We have been given a path that ends in metallib so try to load it
|
||||
if (lib_path.size() > 9 &&
|
||||
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
|
||||
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib from <" << lib_path << "> with error "
|
||||
<< error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// We have been given a path so try to load from lib_path / lib_name.metallib
|
||||
if (lib_path.size() > 0) {
|
||||
std::string full_path = lib_path + "/" + lib_name + ".metallib";
|
||||
auto [lib, error] = load_library_from_path(device, full_path.c_str());
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib from <" << full_path
|
||||
<< "> with error " << error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Try to load the colocated library
|
||||
{
|
||||
auto [lib, error] = load_colocated_library(device, lib_name);
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
// try to load from a swiftpm resource bundle -- scan the available bundles to
|
||||
// find one that contains the named bundle
|
||||
// Try to load the library from swiftpm
|
||||
{
|
||||
MTL::Library* library =
|
||||
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
auto [lib, error] = load_swiftpm_library(device, lib_name);
|
||||
if (lib) {
|
||||
return lib;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Couldn't find it so let's load it from default_mtllib_path
|
||||
{
|
||||
auto [lib, error] = load_library_from_path(device, lib_path);
|
||||
if (!lib) {
|
||||
std::ostringstream msg;
|
||||
msg << error->localizedDescription()->utf8String() << "\n"
|
||||
<< "Failed to load device library from <" << lib_path << ">"
|
||||
<< " or <" << first_path << ">.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return lib;
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
|
||||
<< ">";
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
msg << " and from the Swift PM bundle.";
|
||||
#endif
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -210,7 +286,7 @@ void CommandEncoder::barrier() {
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_library(device_)}};
|
||||
library_map_ = {{"mlx", load_default_library(device_)}};
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
|
@@ -189,15 +189,7 @@ class Device {
|
||||
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path);
|
||||
|
||||
// Note, this should remain in the header so that it is not dynamically
|
||||
// linked
|
||||
void register_library(const std::string& lib_name) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
}
|
||||
}
|
||||
const std::string& lib_path = "");
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
|
@@ -24,10 +24,6 @@ void Event::wait() {
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal() {
|
||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||
}
|
||||
|
||||
void Event::wait(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||
@@ -42,7 +38,9 @@ void Event::wait(Stream stream) {
|
||||
|
||||
void Event::signal(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { signal(); });
|
||||
scheduler::enqueue(stream, [*this]() mutable {
|
||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||
});
|
||||
} else {
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
|
@@ -20,7 +20,7 @@ struct FenceImpl {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
fence = static_cast<void*>(d->newSharedEvent());
|
||||
} else {
|
||||
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
|
||||
auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
|
||||
fence = static_cast<void*>(buf);
|
||||
cpu_value()[0] = 0;
|
||||
}
|
||||
|
@@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
|
||||
}
|
||||
|
||||
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
|
||||
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
|
||||
b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
|
||||
auto b_q_fft_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
|
||||
std::ptrdiff_t item_size = b_q_fft.itemsize();
|
||||
@@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
|
||||
}
|
||||
|
||||
array w_k({n}, complex64, nullptr, {});
|
||||
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
|
||||
w_k.set_data(allocator::malloc(w_k.nbytes()));
|
||||
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
|
||||
|
||||
array w_q({bluestein_n}, complex64, nullptr, {});
|
||||
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
|
||||
w_q.set_data(allocator::malloc(w_q.nbytes()));
|
||||
auto w_q_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
|
||||
|
||||
@@ -356,20 +356,14 @@ void multi_upload_bluestein_fft(
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
std::vector<array>& copies,
|
||||
const Stream& s) {
|
||||
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
|
||||
auto temp_shape = inverse ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
@@ -378,13 +372,13 @@ void multi_upload_bluestein_fft(
|
||||
if (real && !inverse) {
|
||||
// Convert float32->complex64
|
||||
copy_gpu(in, temp, CopyType::General, s);
|
||||
copies.push_back(temp);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = n % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
@@ -394,19 +388,28 @@ void multi_upload_bluestein_fft(
|
||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||
copies.push_back(temp);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in}, temp, "Conjugate", s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
temp.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
std::vector<std::pair<int, int>> pads;
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_n;
|
||||
array pad_temp(padded_shape, complex64, nullptr, {});
|
||||
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
|
||||
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||
copies.push_back(zero);
|
||||
pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
|
||||
copies.push_back(pad_temp);
|
||||
|
||||
array pad_temp1(padded_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
@@ -418,7 +421,10 @@ void multi_upload_bluestein_fft(
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
s);
|
||||
copies.push_back(pad_temp1);
|
||||
|
||||
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
|
||||
|
||||
fft_op(
|
||||
@@ -435,9 +441,11 @@ void multi_upload_bluestein_fft(
|
||||
Shape starts(in.ndim(), 0);
|
||||
Shape strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
slice_gpu(pad_temp1, temp2, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
@@ -449,26 +457,21 @@ void multi_upload_bluestein_fft(
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
copies.push_back(inv_n);
|
||||
copies.push_back(temp1);
|
||||
|
||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
|
||||
} else if (inverse) {
|
||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||
unary_op_gpu({temp1}, temp, "Conjugate", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
|
||||
array temp3(temp_shape, complex64, nullptr, {});
|
||||
unary_op_gpu({temp1}, temp3, "Conjugate", s);
|
||||
binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
|
||||
copies.push_back(inv_n);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(temp3);
|
||||
} else {
|
||||
out.copy_shared_buffer(temp1);
|
||||
}
|
||||
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k_broadcast);
|
||||
copies.push_back(w_q_broadcast);
|
||||
copies.push_back(temp);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(pad_temp);
|
||||
copies.push_back(pad_temp1);
|
||||
}
|
||||
|
||||
void four_step_fft(
|
||||
@@ -478,8 +481,9 @@ void four_step_fft(
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
std::vector<array>& copies,
|
||||
const Stream& s,
|
||||
bool in_place) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
if (plan.bluestein_n == -1) {
|
||||
@@ -492,7 +496,14 @@ void four_step_fft(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
temp,
|
||||
out,
|
||||
axis,
|
||||
inverse,
|
||||
real,
|
||||
four_step_params,
|
||||
/*inplace=*/in_place,
|
||||
s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
@@ -551,8 +562,7 @@ void fft_op(
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.contiguous = data_size == x_copy.size();
|
||||
|
||||
x_copy.set_data(
|
||||
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
|
||||
x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
|
||||
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
@@ -575,7 +585,7 @@ void fft_op(
|
||||
|
||||
auto plan = plan_fft(n);
|
||||
if (plan.four_step) {
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
@@ -583,7 +593,7 @@ void fft_op(
|
||||
// TODO: allow donation here
|
||||
if (!inplace) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
allocator::malloc(out.nbytes()),
|
||||
out_data_size,
|
||||
out_strides,
|
||||
in_contiguous.flags());
|
||||
|
@@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.copy_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
int n, m;
|
||||
@@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||
temp.set_data(allocator::malloc(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
|
||||
|
@@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
@@ -1,9 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view arange_kernels = R"(
|
||||
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
|
||||
constant const {1}& start,
|
||||
constant const {1}& step,
|
||||
device {1}* out,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
)";
|
@@ -20,6 +20,7 @@ const char* copy();
|
||||
const char* fft();
|
||||
const char* gather_axis();
|
||||
const char* hadamard();
|
||||
const char* logsumexp();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
|
@@ -1,23 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view softmax_kernels = R"(
|
||||
template [[host_name("block_{0}")]] [[kernel]] void
|
||||
softmax_single_row<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("looped_{0}")]] [[kernel]] void
|
||||
softmax_looped<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
@@ -1,8 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::arange()
|
||||
<< fmt::format(
|
||||
arange_kernels,
|
||||
kernel_name,
|
||||
get_type_string(out.dtype()));
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::arange();
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, "arange", get_type_string(out.dtype()));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::softmax()
|
||||
<< fmt::format(
|
||||
softmax_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
get_type_string(precise ? float32 : out.dtype()));
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
auto in_type = get_type_string(out.dtype());
|
||||
auto acc_type = get_type_string(precise ? float32 : out.dtype());
|
||||
kernel_source += metal::softmax();
|
||||
kernel_source += get_template_definition(
|
||||
"block_" + lib_name, "softmax_single_row", in_type, acc_type);
|
||||
kernel_source += get_template_definition(
|
||||
"looped_" + lib_name, "softmax_looped", in_type, acc_type);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
auto t_str = get_type_string(out.dtype());
|
||||
std::string kernel_source;
|
||||
kernel_source = metal::utils();
|
||||
kernel_source += metal::logsumexp();
|
||||
kernel_source +=
|
||||
get_template_definition("block_" + lib_name, "logsumexp", t_str);
|
||||
kernel_source += get_template_definition(
|
||||
"looped_" + lib_name, "logsumexp_looped", t_str);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
bool precise,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_scan_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@@ -13,6 +13,10 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
||||
@@ -105,6 +109,7 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
|
@@ -5,11 +5,7 @@
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
|
||||
constant const type& start, \
|
||||
constant const type& step, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
instantiate_kernel("arange" #tname, arange, type)
|
||||
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
instantiate_arange(uint16, uint16_t)
|
||||
|
@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
|
||||
instantiate_naive_conv_2d_blocks(float16, half);
|
||||
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Depthwise convolution kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constant int ker_h [[function_constant(00)]];
|
||||
constant int ker_w [[function_constant(01)]];
|
||||
constant int str_h [[function_constant(10)]];
|
||||
constant int str_w [[function_constant(11)]];
|
||||
constant int tgp_h [[function_constant(100)]];
|
||||
constant int tgp_w [[function_constant(101)]];
|
||||
constant bool do_flip [[function_constant(200)]];
|
||||
|
||||
constant int span_h = tgp_h * str_h + ker_h - 1;
|
||||
constant int span_w = tgp_w * str_w + ker_w - 1;
|
||||
constant int span_hw = span_h * span_w;
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void depthwise_conv_2d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* wt [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int tc = 8;
|
||||
constexpr int tw = 8;
|
||||
constexpr int th = 4;
|
||||
|
||||
constexpr int c_per_thr = 8;
|
||||
|
||||
constexpr int TGH = th * 2 + 6;
|
||||
constexpr int TGW = tw * 2 + 6;
|
||||
constexpr int TGC = tc;
|
||||
|
||||
threadgroup T ins[TGH * TGW * TGC];
|
||||
|
||||
const int n_tgblocks_h = params.oS[0] / th;
|
||||
const int n = tid.z / n_tgblocks_h;
|
||||
const int tghid = tid.z % n_tgblocks_h;
|
||||
const int oh = tghid * th + lid.z;
|
||||
const int ow = gid.y;
|
||||
const int c = gid.x;
|
||||
|
||||
in += n * params.in_strides[0];
|
||||
|
||||
// Load in
|
||||
{
|
||||
constexpr int n_threads = th * tw * tc;
|
||||
const int tg_oh = (tghid * th) * str_h - params.pad[0];
|
||||
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
|
||||
const int tg_c = tid.x * tc;
|
||||
|
||||
const int thread_idx = simd_gid * 32 + simd_lid;
|
||||
constexpr int thr_per_hw = tc / c_per_thr;
|
||||
constexpr int hw_per_group = n_threads / thr_per_hw;
|
||||
|
||||
const int thr_c = thread_idx % thr_per_hw;
|
||||
const int thr_hw = thread_idx / thr_per_hw;
|
||||
|
||||
for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
|
||||
const int h = hw / span_w;
|
||||
const int w = hw % span_w;
|
||||
|
||||
const int ih = tg_oh + h;
|
||||
const int iw = tg_ow + w;
|
||||
|
||||
const int in_s_offset = h * span_w * TGC + w * TGC;
|
||||
|
||||
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
|
||||
const auto in_load =
|
||||
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int cc = 0; cc < c_per_thr; ++cc) {
|
||||
ins[in_s_offset + c_per_thr * thr_c + cc] =
|
||||
in_load[c_per_thr * thr_c + cc];
|
||||
}
|
||||
} else {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int cc = 0; cc < c_per_thr; ++cc) {
|
||||
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
wt += c * params.wt_strides[0];
|
||||
|
||||
const auto ins_ptr =
|
||||
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
|
||||
float o = 0.;
|
||||
for (int h = 0; h < ker_h; ++h) {
|
||||
for (int w = 0; w < ker_w; ++w) {
|
||||
int wt_h = h;
|
||||
int wt_w = w;
|
||||
if (do_flip) {
|
||||
wt_h = ker_h - h - 1;
|
||||
wt_w = ker_w - w - 1;
|
||||
}
|
||||
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
|
||||
auto wtv = wt[wt_h * ker_w + wt_w];
|
||||
o += inv * wtv;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
out += n * params.out_strides[0] + oh * params.out_strides[1] +
|
||||
ow * params.out_strides[2];
|
||||
out[c] = static_cast<T>(o);
|
||||
}
|
||||
|
||||
#define instantiate_depthconv2d(iname, itype) \
|
||||
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)
|
||||
|
||||
instantiate_depthconv2d(float32, float);
|
||||
instantiate_depthconv2d(float16, half);
|
||||
instantiate_depthconv2d(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Winograd kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
@@ -483,4 +483,4 @@ template <
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write_strided(stride, overall_n);
|
||||
}
|
||||
}
|
||||
|
@@ -341,7 +341,7 @@ struct GEMVTKernel {
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
auto vc = float(v_coeff[tm]);
|
||||
auto vc = static_cast<AccT>(v_coeff[tm]);
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
|
@@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gb, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
instantiate_layer_norm_single_row(name, itype) \
|
||||
instantiate_layer_norm_looped(name, itype)
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
|
||||
instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
|
||||
instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
|
||||
instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
|
||||
|
||||
instantiate_layer_norm(float32, float)
|
||||
instantiate_layer_norm(float16, half)
|
||||
|
143
mlx/backend/metal/kernels/logsumexp.h
Normal file
143
mlx/backend/metal/kernels/logsumexp.h
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
template <typename T, typename AccT = float, int N_READS = 4>
|
||||
[[kernel]] void logsumexp(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
AccT ld[N_READS];
|
||||
|
||||
in += gid * size_t(axis_size) + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = AccT(in[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
local_max[simd_lane_id] = Limits<AccT>::min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Get the max
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||
}
|
||||
maxval = simd_max(maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[0] = maxval;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
AccT normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += fast::exp(ld[i] - maxval);
|
||||
}
|
||||
normalizer = simd_sum(normalizer);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename AccT = float, int N_READS = 4>
|
||||
[[kernel]] void logsumexp_looped(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * size_t(axis_size);
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
// Get the max and the normalizer in one go
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
AccT vals[N_READS];
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = AccT(in[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
|
||||
: Limits<AccT>::finite_min;
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < vals[i]) ? vals[i] : maxval;
|
||||
}
|
||||
normalizer *= fast::exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += fast::exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
maxval = simd_max(maxval);
|
||||
normalizer *= fast::exp(prevmax - maxval);
|
||||
normalizer = simd_sum(normalizer);
|
||||
|
||||
prevmax = maxval;
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
normalizer *= fast::exp(prevmax - maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
}
|
||||
}
|
||||
}
|
18
mlx/backend/metal/kernels/logsumexp.metal
Normal file
18
mlx/backend/metal/kernels/logsumexp.metal
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/logsumexp.h"
|
||||
|
||||
#define instantiate_logsumexp(name, itype) \
|
||||
instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
|
||||
instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
|
||||
|
||||
instantiate_logsumexp(float32, float)
|
||||
instantiate_logsumexp(float16, half)
|
||||
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on
|
@@ -586,13 +586,13 @@ METAL_FUNC void qmv_quad_impl(
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
const int in_vec_size_g = in_vec_size / group_size;
|
||||
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
|
||||
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
|
||||
|
||||
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
|
||||
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
||||
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
||||
x += tid.y * in_vec_size + quad_lid * values_per_thread;
|
||||
y += tid.y * out_vec_size + out_row;
|
||||
x += tid.x * in_vec_size + quad_lid * values_per_thread;
|
||||
y += tid.x * out_vec_size + out_row;
|
||||
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
@@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t a = 89226354, uint32_t b = 64248484, uint32_t m = 996162400>
|
||||
float inst3(uint16_t xi) {
|
||||
uint32_t x = xi;
|
||||
x = a * x + b;
|
||||
x = (x & 0b10001111111111111000111111111111) ^ m;
|
||||
auto xf = reinterpret_cast<thread float16_t*>(&x);
|
||||
return xf[0] + xf[1];
|
||||
}
|
||||
|
||||
template <typename T, int bits>
|
||||
METAL_FUNC void qmv_trellis_impl(
|
||||
const device uint32_t* w,
|
||||
const device T* scales,
|
||||
const device T* biases,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
const constant int& out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int packs_per_thread = 2;
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||
constexpr int reads_per = 16 / bits;
|
||||
constexpr int local_w_size =
|
||||
results_per_simdgroup * values_per_thread / reads_per;
|
||||
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread uint16_t w_thread[local_w_size];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||
simd_gid * results_per_simdgroup;
|
||||
|
||||
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
||||
x += tid.x * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.x * out_vec_size + out_row;
|
||||
|
||||
T scale = scales[0];
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_thread; i++) {
|
||||
x_thread[i] = x[i];
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_thread / reads_per; i++) {
|
||||
auto wl = (const device uint16_t*)(ws + row * in_vec_size_w);
|
||||
w_thread[row * values_per_thread / reads_per + i] = wl[i];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_thread / reads_per; i++) {
|
||||
int index = row * values_per_thread / reads_per + i;
|
||||
uint16_t w0 = w_thread[index];
|
||||
uint16_t w1 = w_thread[(index + 1) % local_w_size];
|
||||
|
||||
uint16_t wx = w0 ^ w1;
|
||||
uint16_t wx1 = wx ^ 1;
|
||||
uint16_t wf = w0 ^ (1 << bits);
|
||||
|
||||
if (bits == 2) {
|
||||
result[row] += x_thread[8 * i] * inst3(w0);
|
||||
result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3));
|
||||
result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf));
|
||||
result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f));
|
||||
result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff));
|
||||
result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff));
|
||||
result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff));
|
||||
result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff));
|
||||
} else if (bits == 4) {
|
||||
result[row] += x_thread[4 * i] * inst3(w0);
|
||||
result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf));
|
||||
result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff));
|
||||
result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ws += block_size * bytes_per_pack / pack_factor;
|
||||
x += block_size;
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
if (simd_lid == 0) {
|
||||
y[row] = static_cast<T>(scale * result[row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void qmv_impl(
|
||||
const device uint32_t* w,
|
||||
@@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, int D, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
int bits,
|
||||
int D,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qmv_quad(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -1354,7 +1469,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
|
||||
quad_lid);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
int bits,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -1393,20 +1513,39 @@ template <typename T, int group_size, int bits, bool batched>
|
||||
b_strides,
|
||||
tid);
|
||||
}
|
||||
qmv_fast_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
if (trellis) {
|
||||
qmv_trellis_impl<T, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
} else {
|
||||
qmv_fast_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -1458,7 +1597,12 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, bool batched>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
bool batched,
|
||||
bool trellis = false>
|
||||
[[kernel]] void qvm(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -1572,6 +1716,7 @@ template <
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const bool batched,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -1630,6 +1775,7 @@ template <
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool batched,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -1685,7 +1831,7 @@ template <
|
||||
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
template <typename T, int group_size, int bits, bool trellis = false>
|
||||
[[kernel]] void bs_qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -1734,20 +1880,34 @@ template <typename T, int group_size, int bits>
|
||||
s_strides,
|
||||
b_strides,
|
||||
tid);
|
||||
qmv_fast_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
if (trellis) {
|
||||
qmv_trellis_impl<T, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
} else {
|
||||
qmv_fast_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
template <typename T, int group_size, int bits, bool trellis = false>
|
||||
[[kernel]] void bs_qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -1809,7 +1969,7 @@ template <typename T, int group_size, int bits>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
template <typename T, int group_size, int bits, bool trellis = false>
|
||||
[[kernel]] void bs_qvm(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -1876,6 +2036,7 @@ template <
|
||||
const int group_size,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -1943,6 +2104,7 @@ template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
bool trellis = false,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -2157,3 +2319,211 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const bool use_overlap,
|
||||
const int bits = 2,
|
||||
const int timesteps = 128>
|
||||
[[kernel]] void trellis_viterbi(
|
||||
const device T* w [[buffer(0)]],
|
||||
device float16_t* score [[buffer(1)]],
|
||||
device uint8_t* pointers [[buffer(2)]],
|
||||
const device uint16_t* overlap [[buffer(3)]],
|
||||
uint3 tid [[thread_position_in_grid]]) {
|
||||
constexpr uint16_t L = 16;
|
||||
constexpr uint L2 = 1 << L;
|
||||
|
||||
uint16_t idx = tid.y * 16;
|
||||
|
||||
threadgroup float16_t swap_V[16384];
|
||||
|
||||
thread float16_t min_V[16] = {0};
|
||||
|
||||
for (uint16_t t = 0; t < timesteps; t++) {
|
||||
uint16_t tt = t % 8 == 0 ? L / bits : t % 8;
|
||||
uint16_t shift = ((tt - 1) % (L / bits)) * bits;
|
||||
uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2;
|
||||
|
||||
uint16_t s000 = 1 << (shift - 6);
|
||||
uint16_t s0 = 1 << (shift - 2);
|
||||
uint16_t s1 = 1 << (shift);
|
||||
uint16_t s2 = 1 << (shift + 2);
|
||||
uint16_t s4 = 1 << (shift + 4);
|
||||
|
||||
if (t > 1) {
|
||||
uint16_t i = 0;
|
||||
uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12));
|
||||
uint16_t hoff = shift > 4 ? 4 : shift == 4 ? 16 : 1;
|
||||
uint16_t ind = idx;
|
||||
|
||||
if (shift == 0) {
|
||||
ind >>= 2;
|
||||
} else if (shift == 14) {
|
||||
ind = (ind & 0xfff) + (ind >> 12);
|
||||
} else if (shift == 2) {
|
||||
} else if (shift == 4) {
|
||||
ind = ((ind >> 4) & 0x3) + (ind & ~0x3f);
|
||||
} else if (shift == 6) {
|
||||
ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2;
|
||||
} else {
|
||||
ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 +
|
||||
((ind / s1) % 4) + (ind / s2) * s2;
|
||||
}
|
||||
|
||||
for (uint16_t high = 0; high < 4; high++) {
|
||||
uint16_t sub_ind = ind;
|
||||
for (uint16_t low = 0; low < 4; low++) {
|
||||
swap_V[sub_ind] = min_V[i];
|
||||
i++;
|
||||
sub_ind += loff;
|
||||
}
|
||||
ind += hoff;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
min_V[i] = swap_V[idx + i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
uint16_t rolled_t = use_overlap ? t : (t + 64) % 128;
|
||||
T w_t = w[tid.x * timesteps + rolled_t];
|
||||
|
||||
for (uint16_t i = 0; i < 4; i++) {
|
||||
thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
|
||||
thread uint16_t min_idx[4] = {0};
|
||||
|
||||
uint16_t ii = idx * 4 + i * 16;
|
||||
uint16_t big_idx = ii;
|
||||
if (shift > 0 && shift < 14) {
|
||||
big_idx = ((ii / s2) % 4) + (ii / s4 * s4);
|
||||
if (shift > 2) {
|
||||
big_idx += ((ii / 16) % s0) * 4;
|
||||
}
|
||||
} else if (shift == 14 && t > 0) {
|
||||
big_idx >>= 2;
|
||||
}
|
||||
|
||||
uint16_t loff = t == 0 ? 4 : s1;
|
||||
uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2;
|
||||
|
||||
for (uint16_t high = 0; high < 4; high++) {
|
||||
uint16_t sub_ind = big_idx;
|
||||
for (uint16_t low = 0; low < 4; low++) {
|
||||
float mse = inst3(sub_ind ^ flip) - w_t;
|
||||
mse *= mse;
|
||||
|
||||
float16_t new_val = min_V[i * 4 + high] + mse;
|
||||
if (new_val < min_val[low]) {
|
||||
min_val[low] = new_val;
|
||||
min_idx[low] = high;
|
||||
}
|
||||
sub_ind += loff;
|
||||
}
|
||||
big_idx += hoff;
|
||||
}
|
||||
|
||||
for (uint16_t j = 0; j < 4; j++) {
|
||||
min_V[i * 4 + j] = min_val[j];
|
||||
pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] =
|
||||
min_idx[j];
|
||||
}
|
||||
}
|
||||
if (t == 0 && use_overlap) {
|
||||
uint16_t over = overlap[tid.x * 128 + 64];
|
||||
over = over & ((1 << 14) - 1);
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
uint16_t rs = (over >> 2) ^ 1;
|
||||
uint16_t ls = (idx + i) & ((1 << 12) - 1);
|
||||
min_V[i] = rs == ls ? min_V[i] : INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (use_overlap) {
|
||||
uint16_t over = overlap[tid.x * 128 + 64];
|
||||
over = over & ((1 << 14) - 1);
|
||||
uint16_t node =
|
||||
(over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4;
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
min_V[i] = (idx + i) == node ? min_V[i] : INFINITY;
|
||||
}
|
||||
}
|
||||
for (uint16_t i = 0; i < 16; i++) {
|
||||
score[tid.x * L2 / 4 + idx + i] = min_V[i];
|
||||
}
|
||||
}
|
||||
|
||||
uint16_t remove_bits(uint16_t i, uint16_t shift) {
|
||||
uint16_t lower = i & ((1 << shift) - 1);
|
||||
uint16_t upper = i & ~((1 << (shift + 2)) - 1);
|
||||
return lower + (upper >> 2);
|
||||
}
|
||||
|
||||
uint16_t swap_bits(uint16_t i, uint16_t shift) {
|
||||
uint16_t diff = ((i >> shift) ^ i) & 0x3;
|
||||
i = i ^ diff;
|
||||
i ^= diff << shift;
|
||||
return i;
|
||||
}
|
||||
|
||||
template <const bool use_overlap, const int bits = 2, const int timesteps = 128>
|
||||
[[kernel]] void trellis_backtrack(
|
||||
const device uint32_t* start [[buffer(0)]],
|
||||
const device uint8_t* pointers [[buffer(1)]],
|
||||
device uint16_t* out [[buffer(2)]],
|
||||
const device uint16_t* overlap [[buffer(3)]],
|
||||
uint3 tid [[thread_position_in_grid]]) {
|
||||
constexpr uint16_t L = 16;
|
||||
|
||||
uint16_t node = start[tid.x];
|
||||
|
||||
uint16_t dir =
|
||||
pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node];
|
||||
|
||||
node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4;
|
||||
node ^= 1;
|
||||
node += dir * 16384;
|
||||
|
||||
out[tid.x * timesteps + timesteps - 1] = node;
|
||||
|
||||
for (int t = timesteps - 2; t >= 0; t--) {
|
||||
uint16_t shift = (t % (L / bits)) * bits;
|
||||
uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift);
|
||||
uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1;
|
||||
uint16_t i = (node & mask) ^ flip;
|
||||
|
||||
if (shift > 0) {
|
||||
i = remove_bits(i, shift);
|
||||
}
|
||||
|
||||
if (t == 0) {
|
||||
i >>= 2;
|
||||
}
|
||||
|
||||
if (t % 2 == 1 || t == 0) {
|
||||
i ^= 1;
|
||||
}
|
||||
|
||||
shift = shift == 0 ? L : shift;
|
||||
|
||||
if (t > 0) {
|
||||
i = swap_bits(i, shift - 2);
|
||||
}
|
||||
|
||||
shift = shift == L ? 0 : shift;
|
||||
|
||||
uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i];
|
||||
if ((t % 8 == 1 && t > 1) || t == 0) {
|
||||
last_p ^= 1;
|
||||
}
|
||||
|
||||
node = ((node & mask) ^ flip) | (last_p << shift);
|
||||
if (t == 0 && use_overlap) {
|
||||
uint16_t over = overlap[tid.x * 128 + 64];
|
||||
over = over & ((1 << 14) - 1);
|
||||
node = (node & 0xfffc) + (over & 0x3);
|
||||
}
|
||||
out[tid.x * timesteps + t] = node;
|
||||
}
|
||||
}
|
@@ -120,4 +120,34 @@
|
||||
instantiate_quantized_groups(6) \
|
||||
instantiate_quantized_groups(8)
|
||||
|
||||
instantiate_kernel(
|
||||
"trellis_viterbi_float16_t_overlap_0",
|
||||
trellis_viterbi,
|
||||
float16_t,
|
||||
false)
|
||||
instantiate_kernel(
|
||||
"trellis_viterbi_float16_t_overlap_1",
|
||||
trellis_viterbi,
|
||||
float16_t,
|
||||
true)
|
||||
|
||||
instantiate_kernel(
|
||||
"trellis_backtrack_overlap_0",
|
||||
trellis_backtrack,
|
||||
false)
|
||||
instantiate_kernel(
|
||||
"trellis_backtrack_overlap_1",
|
||||
trellis_backtrack,
|
||||
true)
|
||||
|
||||
instantiate_kernel(
|
||||
"qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
|
||||
qmv_fast,
|
||||
float16_t,
|
||||
64,
|
||||
2,
|
||||
false,
|
||||
true)
|
||||
|
||||
|
||||
instantiate_quantized_all() // clang-format on
|
||||
|
@@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rms_single_row(name, itype) \
|
||||
template [[host_name("rms" #name)]] [[kernel]] void \
|
||||
rms_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
|
||||
vjp_rms_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
|
||||
vjp_rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms(name, itype) \
|
||||
instantiate_rms_single_row(name, itype) \
|
||||
instantiate_rms_looped(name, itype)
|
||||
#define instantiate_rms(name, itype) \
|
||||
instantiate_kernel("rms" #name, rms_single_row, itype) \
|
||||
instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
|
||||
instantiate_kernel("rms_looped" #name, rms_looped, itype) \
|
||||
instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
|
||||
|
||||
instantiate_rms(float32, float)
|
||||
instantiate_rms(float16, half)
|
||||
|
@@ -1,11 +1,11 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/sdpa_vector.h"
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/sdpa_vector.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// clang-format off
|
||||
// SDPA vector instantiations
|
||||
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
|
||||
instantiate_kernel( \
|
||||
@@ -32,9 +32,11 @@ using namespace metal;
|
||||
instantiate_sdpa_vector(type, 64, 64) \
|
||||
instantiate_sdpa_vector(type, 96, 96) \
|
||||
instantiate_sdpa_vector(type, 128, 128) \
|
||||
instantiate_sdpa_vector(type, 256, 256) \
|
||||
instantiate_sdpa_vector_aggregation(type, 64) \
|
||||
instantiate_sdpa_vector_aggregation(type, 96) \
|
||||
instantiate_sdpa_vector_aggregation(type, 128)
|
||||
instantiate_sdpa_vector_aggregation(type, 128) \
|
||||
instantiate_sdpa_vector_aggregation(type, 256)
|
||||
|
||||
instantiate_sdpa_vector_heads(float)
|
||||
instantiate_sdpa_vector_heads(bfloat16_t)
|
||||
|
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
|
||||
#define DEFINE_SIMD_SCAN() \
|
||||
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
|
||||
T simd_scan(T val) { \
|
||||
@@ -139,6 +141,29 @@ struct CumMin {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename U>
|
||||
struct CumLogaddexp {
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
|
||||
template <typename T>
|
||||
U operator()(U a, T b) {
|
||||
return LogAddExp{}(a, static_cast<U>(b));
|
||||
}
|
||||
|
||||
U simd_scan(U x) {
|
||||
for (int i = 1; i <= 16; i *= 2) {
|
||||
U other = simd_shuffle_and_fill_up(x, init, i);
|
||||
x = LogAddExp{}(x, other);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
U simd_exclusive_scan(U x) {
|
||||
x = simd_scan(x);
|
||||
return simd_shuffle_and_fill_up(x, init, 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int N_READS, bool reverse>
|
||||
inline void load_unsafe(U values[N_READS], const device T* input) {
|
||||
if (reverse) {
|
||||
|
@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
|
||||
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
|
||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
||||
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
|
||||
|
@@ -6,6 +6,9 @@ using namespace metal;
|
||||
|
||||
constant bool has_mask [[function_constant(20)]];
|
||||
constant bool query_transposed [[function_constant(21)]];
|
||||
constant bool do_causal [[function_constant(22)]];
|
||||
constant bool bool_mask [[function_constant(23)]];
|
||||
constant bool float_mask [[function_constant(24)]];
|
||||
|
||||
template <typename T, int D, int V = D>
|
||||
[[kernel]] void sdpa_vector(
|
||||
@@ -13,17 +16,21 @@ template <typename T, int D, int V = D>
|
||||
const device T* keys [[buffer(1)]],
|
||||
const device T* values [[buffer(2)]],
|
||||
device T* out [[buffer(3)]],
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_head_stride,
|
||||
const constant size_t& k_seq_stride,
|
||||
const constant size_t& v_head_stride,
|
||||
const constant size_t& v_seq_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
||||
const constant int& gqa_factor [[buffer(4)]],
|
||||
const constant int& N [[buffer(5)]],
|
||||
const constant size_t& k_head_stride [[buffer(6)]],
|
||||
const constant size_t& k_seq_stride [[buffer(7)]],
|
||||
const constant size_t& v_head_stride [[buffer(8)]],
|
||||
const constant size_t& v_seq_stride [[buffer(9)]],
|
||||
const constant float& scale [[buffer(10)]],
|
||||
const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
|
||||
const device T* fmask [[buffer(12), function_constant(float_mask)]],
|
||||
const constant int& mask_kv_seq_stride
|
||||
[[buffer(13), function_constant(has_mask)]],
|
||||
const constant int& mask_q_seq_stride
|
||||
[[buffer(14), function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride
|
||||
[[buffer(15), function_constant(has_mask)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threadgroups_per_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -57,8 +64,12 @@ template <typename T, int D, int V = D>
|
||||
simd_lid * qk_per_thread;
|
||||
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
|
||||
simd_lid * v_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||
if (bool_mask) {
|
||||
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
if (float_mask) {
|
||||
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
|
||||
@@ -77,7 +88,13 @@ template <typename T, int D, int V = D>
|
||||
|
||||
// For each key
|
||||
for (int i = simd_gid; i < N; i += BN) {
|
||||
if (!has_mask || mask[0]) {
|
||||
bool use_key = true;
|
||||
if (do_causal) {
|
||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||
} else if (bool_mask) {
|
||||
use_key = bmask[0];
|
||||
}
|
||||
if (use_key) {
|
||||
// Read the key
|
||||
for (int j = 0; j < qk_per_thread; j++) {
|
||||
k[j] = keys[j];
|
||||
@@ -89,6 +106,9 @@ template <typename T, int D, int V = D>
|
||||
score += q[j] * k[j];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
if (float_mask) {
|
||||
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
|
||||
}
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
@@ -107,8 +127,11 @@ template <typename T, int D, int V = D>
|
||||
// Move the pointers to the next kv
|
||||
keys += inner_k_stride;
|
||||
values += inner_v_stride;
|
||||
if (has_mask) {
|
||||
mask += BN * mask_kv_seq_stride;
|
||||
if (bool_mask) {
|
||||
bmask += BN * mask_kv_seq_stride;
|
||||
}
|
||||
if (float_mask) {
|
||||
fmask += BN * mask_kv_seq_stride;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,17 +172,21 @@ template <typename T, int D, int V = D>
|
||||
device float* out [[buffer(3)]],
|
||||
device float* sums [[buffer(4)]],
|
||||
device float* maxs [[buffer(5)]],
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_head_stride,
|
||||
const constant size_t& k_seq_stride,
|
||||
const constant size_t& v_head_stride,
|
||||
const constant size_t& v_seq_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
||||
const constant int& gqa_factor [[buffer(6)]],
|
||||
const constant int& N [[buffer(7)]],
|
||||
const constant size_t& k_head_stride [[buffer(8)]],
|
||||
const constant size_t& k_seq_stride [[buffer(9)]],
|
||||
const constant size_t& v_head_stride [[buffer(10)]],
|
||||
const constant size_t& v_seq_stride [[buffer(11)]],
|
||||
const constant float& scale [[buffer(12)]],
|
||||
const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
|
||||
const device T* fmask [[buffer(14), function_constant(float_mask)]],
|
||||
const constant int& mask_kv_seq_stride
|
||||
[[buffer(15), function_constant(has_mask)]],
|
||||
const constant int& mask_q_seq_stride
|
||||
[[buffer(16), function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride
|
||||
[[buffer(17), function_constant(has_mask)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threadgroups_per_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
@@ -197,8 +224,13 @@ template <typename T, int D, int V = D>
|
||||
values += kv_head_idx * v_head_stride +
|
||||
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
|
||||
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride +
|
||||
if (bool_mask) {
|
||||
bmask += head_idx * mask_head_stride +
|
||||
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
if (float_mask) {
|
||||
fmask += head_idx * mask_head_stride +
|
||||
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
@@ -218,7 +250,13 @@ template <typename T, int D, int V = D>
|
||||
|
||||
// For each key
|
||||
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
||||
if (!has_mask || mask[0]) {
|
||||
bool use_key = true;
|
||||
if (do_causal) {
|
||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||
} else if (bool_mask) {
|
||||
use_key = bmask[0];
|
||||
}
|
||||
if (use_key) {
|
||||
// Read the key
|
||||
for (int i = 0; i < qk_per_thread; i++) {
|
||||
k[i] = keys[i];
|
||||
@@ -230,6 +268,9 @@ template <typename T, int D, int V = D>
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
if (float_mask) {
|
||||
score += fmask[0];
|
||||
}
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
@@ -248,8 +289,11 @@ template <typename T, int D, int V = D>
|
||||
// Move the pointers to the next kv
|
||||
keys += blocks * inner_k_stride;
|
||||
values += blocks * inner_v_stride;
|
||||
if (has_mask) {
|
||||
mask += BN * blocks * mask_kv_seq_stride;
|
||||
if (bool_mask) {
|
||||
bmask += BN * blocks * mask_kv_seq_stride;
|
||||
}
|
||||
if (float_mask) {
|
||||
fmask += BN * blocks * mask_kv_seq_stride;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -9,47 +9,13 @@ using namespace metal;
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/softmax.h"
|
||||
|
||||
#define instantiate_softmax(name, itype) \
|
||||
template [[host_name("block_softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
#define instantiate_softmax(name, itype) \
|
||||
instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
|
||||
instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
|
||||
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
|
||||
instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
|
||||
|
||||
instantiate_softmax(float32, float)
|
||||
instantiate_softmax(float16, half)
|
||||
|
@@ -229,7 +229,7 @@ template <
|
||||
// Init to -Inf
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = Limits<AccumType>::min;
|
||||
max_score[i] = Limits<AccumType>::finite_min;
|
||||
}
|
||||
|
||||
int kb_lim = params->NK;
|
||||
@@ -237,6 +237,7 @@ template <
|
||||
if (do_causal) {
|
||||
int q_max = (tid.x + 1) * BQ + params->qL_off;
|
||||
kb_lim = (q_max + BK - 1) / BK;
|
||||
kb_lim = min(params->NK, kb_lim);
|
||||
}
|
||||
|
||||
// Loop over KV seq length
|
||||
@@ -272,7 +273,7 @@ template <
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||
@@ -290,10 +291,10 @@ template <
|
||||
}
|
||||
|
||||
// Mask out if causal
|
||||
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
|
||||
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||
@@ -316,7 +317,7 @@ template <
|
||||
if (has_mask) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||
|
||||
constexpr bool is_bool = is_same_v<MaskType, bool>;
|
||||
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
|
||||
|
@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log2, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log10, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||
|
@@ -257,6 +257,13 @@ struct Log {
|
||||
T operator()(T x) {
|
||||
return metal::precise::log(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto r = metal::precise::log(Abs{}(x).real);
|
||||
auto i = metal::precise::atan2(x.imag, x.real);
|
||||
return {r, i};
|
||||
};
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
@@ -264,6 +271,12 @@ struct Log2 {
|
||||
T operator()(T x) {
|
||||
return metal::precise::log2(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
||||
};
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
@@ -271,6 +284,12 @@ struct Log10 {
|
||||
T operator()(T x) {
|
||||
return metal::precise::log10(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
||||
};
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
|
96
mlx/backend/metal/logsumexp.cpp
Normal file
96
mlx/backend/metal/logsumexp.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
|
||||
|
||||
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[logsumexp] Does not support non-floating point types.");
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto ensure_contiguous = [&s, &d](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = 4;
|
||||
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
|
||||
|
||||
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
|
||||
kernel_name += "logsumexp_";
|
||||
kernel_name += type_to_name(out);
|
||||
|
||||
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(axis_size, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -382,7 +382,7 @@ void steel_matmul(
|
||||
int split_k_partition_size = gemm_k_iterations * bk;
|
||||
|
||||
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
|
||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||
C_split.set_data(allocator::malloc(C_split.nbytes()));
|
||||
copies.push_back(C_split);
|
||||
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
@@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
@@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
@@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int split_k_partition_size = gemm_k_iterations * bk;
|
||||
|
||||
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
|
||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||
C_split.set_data(allocator::malloc(C_split.nbytes()));
|
||||
copies.push_back(C_split);
|
||||
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
@@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
@@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
@@ -12,71 +12,6 @@ namespace mlx::core::metal {
|
||||
/* Check if the Metal backend is available. */
|
||||
bool is_available();
|
||||
|
||||
/* Get the actively used memory in bytes.
|
||||
*
|
||||
* Note, this will not always match memory use reported by the system because
|
||||
* it does not include cached memory buffers.
|
||||
* */
|
||||
size_t get_active_memory();
|
||||
|
||||
/* Get the peak amount of used memory in bytes.
|
||||
*
|
||||
* The maximum memory used recorded from the beginning of the program
|
||||
* execution or since the last call to reset_peak_memory.
|
||||
* */
|
||||
size_t get_peak_memory();
|
||||
|
||||
/* Reset the peak memory to zero.
|
||||
* */
|
||||
void reset_peak_memory();
|
||||
|
||||
/* Get the cache size in bytes.
|
||||
*
|
||||
* The cache includes memory not currently used that has not been returned
|
||||
* to the system allocator.
|
||||
* */
|
||||
size_t get_cache_memory();
|
||||
|
||||
/* Set the memory limit.
|
||||
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
|
||||
* there are no more scheduled tasks an error will be raised if relaxed
|
||||
* is false or memory will be allocated (including the potential for
|
||||
* swap) if relaxed is true.
|
||||
*
|
||||
* The memory limit defaults to 1.5 times the maximum recommended working set
|
||||
* size reported by the device.
|
||||
*
|
||||
* Returns the previous memory limit.
|
||||
* */
|
||||
size_t set_memory_limit(size_t limit, bool relaxed = true);
|
||||
|
||||
/* Set the free cache limit.
|
||||
* If using more than the given limit, free memory will be reclaimed
|
||||
* from the cache on the next allocation. To disable the cache,
|
||||
* set the limit to 0.
|
||||
*
|
||||
* The cache limit defaults to the memory limit.
|
||||
*
|
||||
* Returns the previous cache limit.
|
||||
* */
|
||||
size_t set_cache_limit(size_t limit);
|
||||
|
||||
/* Clear the memory cache. */
|
||||
void clear_cache();
|
||||
|
||||
/* Set the wired size limit.
|
||||
*
|
||||
* Note, this function is only useful for macOS 15.0 or higher.
|
||||
*
|
||||
* The wired limit is the total size in bytes of memory that will be kept
|
||||
* resident. The default value is ``0``.
|
||||
*
|
||||
* Setting a wired limit larger than system wired limit is an error.
|
||||
*
|
||||
* Returns the previous wired limit.
|
||||
* */
|
||||
size_t set_wired_limit(size_t limit);
|
||||
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
void start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
@@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_scan_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@@ -29,7 +29,7 @@ void RMSNorm::eval_gpu(
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu(
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
d.add_temporary(g, s.index);
|
||||
@@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu(
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
d.add_temporary(gw_temp, s.index);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
@@ -226,7 +226,7 @@ void LayerNorm::eval_gpu(
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu(
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
d.add_temporary(g, s.index);
|
||||
@@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu(
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
d.add_temporary(gw_temp, s.index);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@@ -16,6 +16,8 @@
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
@@ -28,7 +30,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
@@ -58,7 +60,7 @@ static array compute_dynamic_offset(
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||
}
|
||||
d.add_temporary(offset, s.index);
|
||||
|
||||
@@ -100,7 +102,7 @@ static array compute_dynamic_offset(
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
void arg_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
std::string op_name,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case ArgReduce::ArgMin:
|
||||
op_name = "argmin_";
|
||||
break;
|
||||
case ArgReduce::ArgMax:
|
||||
op_name = "argmax_";
|
||||
break;
|
||||
}
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
auto in_strides = in.strides();
|
||||
auto shape = in.shape();
|
||||
auto out_strides = out.strides();
|
||||
auto axis_stride = in_strides[axis_];
|
||||
size_t axis_size = shape[axis_];
|
||||
auto axis_stride = in_strides[axis];
|
||||
size_t axis_size = shape[axis];
|
||||
if (out_strides.size() == in_strides.size()) {
|
||||
out_strides.erase(out_strides.begin() + axis_);
|
||||
out_strides.erase(out_strides.begin() + axis);
|
||||
}
|
||||
in_strides.erase(in_strides.begin() + axis_);
|
||||
shape.erase(shape.begin() + axis_);
|
||||
in_strides.erase(in_strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
size_t ndim = shape.size();
|
||||
|
||||
// ArgReduce
|
||||
@@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int n_reads = 4;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name + type_to_name(in));
|
||||
auto kernel = d.get_kernel(op_name + "_" + type_to_name(in));
|
||||
NS::UInteger thread_group_size = std::min(
|
||||
(axis_size + n_reads - 1) / n_reads,
|
||||
kernel->maxTotalThreadsPerThreadgroup());
|
||||
@@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case ArgReduce::ArgMin:
|
||||
op_name = "argmin";
|
||||
break;
|
||||
case ArgReduce::ArgMax:
|
||||
op_name = "argmax";
|
||||
break;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
arg_reduce_dispatch(in, out, axis_, op_name, s);
|
||||
}
|
||||
|
||||
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
CopyType ctype =
|
||||
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
@@ -251,8 +262,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
@@ -333,7 +346,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -397,7 +410,7 @@ void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& start = inputs[1];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto s = stream();
|
||||
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
|
||||
copy_gpu_inplace(
|
||||
@@ -554,7 +567,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
copy_gpu_inplace(in, tmp, CopyType::General, stream());
|
||||
|
||||
auto flags = out.flags();
|
||||
|
@@ -7,11 +7,14 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void launch_qmm(
|
||||
@@ -31,6 +34,7 @@ void launch_qmm(
|
||||
bool gather,
|
||||
bool aligned,
|
||||
bool quad,
|
||||
const std::string& mode,
|
||||
const Stream& s) {
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
@@ -54,8 +58,12 @@ void launch_qmm(
|
||||
};
|
||||
auto x = ensure_row_contiguous_last_dims(x_pre);
|
||||
auto w = ensure_row_contiguous_last_dims(w_pre);
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
auto scales = scales_pre;
|
||||
auto biases = biases_pre;
|
||||
if (mode == "affine") {
|
||||
scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
}
|
||||
|
||||
int x_batch_ndims = x.ndim() - 2;
|
||||
auto& x_shape = x.shape();
|
||||
@@ -68,6 +76,8 @@ void launch_qmm(
|
||||
|
||||
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
|
||||
|
||||
bool is_trellis = (mode == "trellis");
|
||||
|
||||
std::ostringstream kname;
|
||||
auto type_string = get_type_string(x.dtype());
|
||||
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
|
||||
@@ -80,24 +90,47 @@ void launch_qmm(
|
||||
if (!gather) {
|
||||
kname << "_batch_" << batched;
|
||||
}
|
||||
if (mode == "trellis") {
|
||||
kname << "_mode_" << is_trellis;
|
||||
}
|
||||
|
||||
// Encode and dispatch kernel
|
||||
std::string template_def;
|
||||
if (quad) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, D, batched);
|
||||
kname.str(),
|
||||
name,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
D,
|
||||
batched,
|
||||
is_trellis);
|
||||
} else if (aligned && !gather) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
|
||||
kname.str(),
|
||||
name,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
aligned_n,
|
||||
batched,
|
||||
is_trellis);
|
||||
} else if (!gather && !aligned) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, batched);
|
||||
kname.str(), name, type_string, group_size, bits, batched, is_trellis);
|
||||
} else if (aligned && gather) {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits, aligned_n);
|
||||
kname.str(),
|
||||
name,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
aligned_n,
|
||||
is_trellis);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
kname.str(), name, type_string, group_size, bits);
|
||||
kname.str(), name, type_string, group_size, bits, is_trellis);
|
||||
}
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
@@ -224,7 +257,7 @@ void qvm_split_k(
|
||||
auto temp_shape = out.shape();
|
||||
temp_shape.insert(temp_shape.end() - 2, split_k);
|
||||
array intermediate(temp_shape, x.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
std::ostringstream kname;
|
||||
@@ -276,8 +309,9 @@ void qmm_op(
|
||||
int group_size,
|
||||
int bits,
|
||||
bool gather,
|
||||
const std::string& mode,
|
||||
const Stream& s) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
@@ -298,23 +332,69 @@ void qmm_op(
|
||||
bool aligned = false;
|
||||
bool quad = false;
|
||||
|
||||
auto get_qmv_batch_limit = [s](int D, int O) {
|
||||
auto arch = metal::device(s.device).get_architecture();
|
||||
auto arch_size = arch.back();
|
||||
auto arch_gen = arch.substr(arch.size() - 3, 2);
|
||||
if (arch_gen == "13" || arch_gen == "14") {
|
||||
switch (arch_size) {
|
||||
case 'd':
|
||||
if (D <= 2048 && O <= 2048) {
|
||||
return 32;
|
||||
} else if (D <= 4096 && O <= 4096) {
|
||||
return 18;
|
||||
} else {
|
||||
return 12;
|
||||
}
|
||||
default:
|
||||
if (D <= 2048 && O <= 2048) {
|
||||
return 14;
|
||||
} else if (D <= 4096 && O <= 4096) {
|
||||
return 10;
|
||||
} else {
|
||||
return 6;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch (arch_size) {
|
||||
case 'd':
|
||||
if (D <= 2048 && O <= 2048) {
|
||||
return 32;
|
||||
} else if (D <= 4096 && O <= 4096) {
|
||||
return 18;
|
||||
} else {
|
||||
return 12;
|
||||
}
|
||||
default:
|
||||
if (D <= 2048 && O <= 2048) {
|
||||
return 18;
|
||||
} else if (D <= 4096 && O <= 4096) {
|
||||
return 12;
|
||||
} else {
|
||||
return 10;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (transpose) {
|
||||
if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
|
||||
auto qmv_batch_limit = get_qmv_batch_limit(D, O);
|
||||
if (B < qmv_batch_limit && (D == 128 || D == 64) && is_power_of_2(bits)) {
|
||||
name += "qmv_quad";
|
||||
constexpr int quads_per_simd = 8;
|
||||
constexpr int results_per_quadgroup = 8;
|
||||
int bo = quads_per_simd * results_per_quadgroup;
|
||||
int simdgroup_size = 32;
|
||||
group_dims = MTL::Size(simdgroup_size, 1, 1);
|
||||
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
|
||||
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
|
||||
quad = true;
|
||||
} else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
|
||||
} else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
|
||||
name += "qmv_fast";
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
group_dims = MTL::Size(bd, 2, 1);
|
||||
grid_dims = MTL::Size(B, O / bo, N);
|
||||
} else if (B < 6) {
|
||||
} else if (B < qmv_batch_limit) {
|
||||
name += "qmv";
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
@@ -374,19 +454,34 @@ void qmm_op(
|
||||
gather,
|
||||
aligned,
|
||||
quad,
|
||||
mode,
|
||||
s);
|
||||
}
|
||||
|
||||
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
qmm_op(
|
||||
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
|
||||
inputs,
|
||||
out,
|
||||
transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
/*gather=*/false,
|
||||
mode_,
|
||||
stream());
|
||||
}
|
||||
|
||||
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
qmm_op(
|
||||
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
|
||||
inputs,
|
||||
out,
|
||||
transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
/*gather=*/true,
|
||||
mode_,
|
||||
stream());
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
@@ -394,7 +489,7 @@ void fast::AffineQuantize::eval_gpu(
|
||||
std::vector<array>& outputs) {
|
||||
auto& w_pre = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@@ -425,8 +520,8 @@ void fast::AffineQuantize::eval_gpu(
|
||||
} else {
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_output_array(scales, 2);
|
||||
compute_encoder.set_output_array(biases, 3);
|
||||
@@ -470,4 +565,123 @@ void fast::AffineQuantize::eval_gpu(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void viterbi(
|
||||
array& w,
|
||||
array& scores,
|
||||
array& pointers,
|
||||
array& start,
|
||||
array& overlap,
|
||||
bool use_overlap,
|
||||
const Stream& s) {
|
||||
int B = scores.shape(0);
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_output_array(scores, 1);
|
||||
compute_encoder.set_output_array(pointers, 2);
|
||||
if (use_overlap) {
|
||||
compute_encoder.set_input_array(overlap, 3);
|
||||
}
|
||||
|
||||
std::ostringstream kname;
|
||||
auto type_string = get_type_string(w.dtype());
|
||||
kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap;
|
||||
auto template_def = get_template_definition(
|
||||
kname.str(), "trellis_viterbi", type_string, use_overlap);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto group_dims = MTL::Size(1, 1024, 1);
|
||||
auto grid_dims = MTL::Size(B, 1024, 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
arg_reduce_dispatch(scores, start, 1, "argmin", s);
|
||||
}
|
||||
|
||||
void viterbi_backtrack(
|
||||
array& start,
|
||||
array& pointers,
|
||||
array& out,
|
||||
array& overlap,
|
||||
bool use_overlap,
|
||||
const Stream& s) {
|
||||
int B = start.shape(0);
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_input_array(start, 0);
|
||||
compute_encoder.set_input_array(pointers, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
if (use_overlap) {
|
||||
compute_encoder.set_input_array(overlap, 3);
|
||||
}
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "trellis_backtrack" << "_overlap_" << use_overlap;
|
||||
auto template_def =
|
||||
get_template_definition(kname.str(), "trellis_backtrack", use_overlap);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto group_dims = MTL::Size(256, 1, 1);
|
||||
auto grid_dims = MTL::Size(B, 1, 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fast::TrellisQuantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& w_pre = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
std::vector<array> copies;
|
||||
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto w = ensure_row_contiguous(w_pre);
|
||||
|
||||
int B = w.shape(0);
|
||||
int T = w.shape(1);
|
||||
|
||||
constexpr int num_states = 1 << 14;
|
||||
|
||||
array scores({B, num_states}, float16, nullptr, {});
|
||||
scores.set_data(allocator::malloc(scores.nbytes()));
|
||||
copies.push_back(scores);
|
||||
|
||||
array pointers({B, T, num_states}, uint8, nullptr, {});
|
||||
pointers.set_data(allocator::malloc(pointers.nbytes()));
|
||||
copies.push_back(pointers);
|
||||
|
||||
array start({B}, uint32, nullptr, {});
|
||||
start.set_data(allocator::malloc(start.nbytes()));
|
||||
copies.push_back(start);
|
||||
|
||||
array rolled({B, T}, uint16, nullptr, {});
|
||||
rolled.set_data(allocator::malloc(rolled.nbytes()));
|
||||
copies.push_back(rolled);
|
||||
|
||||
viterbi(w, scores, pointers, start, out, false, s);
|
||||
viterbi_backtrack(start, pointers, rolled, out, false, s);
|
||||
|
||||
viterbi(w, scores, pointers, start, rolled, true, s);
|
||||
viterbi_backtrack(start, pointers, out, rolled, true, s);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -347,7 +347,7 @@ void all_reduce_dispatch(
|
||||
|
||||
// Allocate an intermediate tensor to hold results if needed
|
||||
array intermediate({n_rows}, out_type, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// 1st pass
|
||||
@@ -641,7 +641,7 @@ void strided_reduce_longcolumn(
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
@@ -812,7 +812,7 @@ void strided_reduce_2pass(
|
||||
intermediate_shape.insert(
|
||||
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||
array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
@@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Minimum of 4 bytes since we use size 4 structs for all reduce
|
||||
// and metal will complain o/w
|
||||
size_t min_bytes = std::max(out.nbytes(), 4ul);
|
||||
out.set_data(allocator::malloc_or_wait(min_bytes));
|
||||
out.set_data(allocator::malloc(min_bytes));
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
|
@@ -38,4 +38,11 @@ void strided_reduce_general_dispatch(
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
void arg_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
std::string op_name,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -22,6 +22,7 @@ ResidencySet::ResidencySet(MTL::Device* d) {
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
wired_set_->requestResidency();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +33,6 @@ void ResidencySet::insert(MTL::Allocation* buf) {
|
||||
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
|
||||
wired_set_->addAllocation(buf);
|
||||
wired_set_->commit();
|
||||
wired_set_->requestResidency();
|
||||
} else {
|
||||
unwired_set_.insert(buf);
|
||||
}
|
||||
@@ -76,7 +76,6 @@ void ResidencySet::resize(size_t size) {
|
||||
}
|
||||
}
|
||||
wired_set_->commit();
|
||||
wired_set_->requestResidency();
|
||||
} else if (current_size > size) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
// Remove wired allocations until under capacity
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user