Compare commits

...

41 Commits

Author SHA1 Message Date
Angelos Katharopoulos
1600092e92 Patch bump (#1376) 2024-08-29 16:54:30 -07:00
Awni Hannun
dba2bd1105 Even Even Faster IO (#1374)
* even more faster io

* make reader pool static

* make python reader thread safe

* one more optimization
2024-08-29 16:05:40 -07:00
Alex Barron
28be4de7c2 Fix JIT reductions (#1373) 2024-08-28 16:39:11 -07:00
Awni Hannun
a6c3b38fba Async load (#1372)
* async load

* async load
2024-08-28 14:21:55 -07:00
Awni Hannun
fcb65a3897 Even Faster I/O (#1369)
* try multithreading for faster IO

* smaller batch size

* Account for pread returning less than size

* nit

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-28 11:49:07 -07:00
Saanidhya
4e22a1dffe In continuation to PR1243 to solve issue #1240 (#1365)
* Solves issue #1240

* Correction

* Update python/mlx/utils.py

* Update python/mlx/utils.py

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-08-28 11:40:41 -07:00
Awni Hannun
291cf40aca Some fixes to typing (#1371)
* some fixes to typing

* fix module reference

* comment
2024-08-28 11:16:19 -07:00
Jeethu Rao
bd47e1f066 Fix neon_fast_exp and add more softmax tests (#1367) 2024-08-27 23:42:42 -07:00
Aditya Dhulipala
e6b223df5f Pinv (#875) 2024-08-27 23:06:12 -07:00
Angelos Katharopoulos
e64349bbdd Make eval just wait if all arrays are scheduled (#1368) 2024-08-27 17:01:22 -07:00
Angelos Katharopoulos
cdb59faea6 Adds send/recv ops in distributed (#1366) 2024-08-26 23:01:37 -07:00
Alex Barron
1d94ac3f90 Add optional headers to `mx.fast.metal_kernel` (#1358) 2024-08-26 21:45:45 -07:00
Awni Hannun
5f7d19d1f5 MPI ops in GPU stream for faster comms (#1356) 2024-08-26 15:12:50 -07:00
Awni Hannun
2fdf9eb535 Fix ternary for large arrays (#1359)
* fix ternary for large arrays

* fix
2024-08-26 11:22:27 -07:00
Awni Hannun
860d3a50d7 fix extension metal library finding (#1361) 2024-08-26 09:18:50 -07:00
Alex Barron
d1183821a7 int() and float() for mx.array (#1360) 2024-08-25 20:41:44 -07:00
Angelos Katharopoulos
8081df79be Fix boolean all reduce bug (#1355) 2024-08-24 10:09:32 -07:00
Nripesh Niketan
64bec4fad7 Chore: update pre-commit hooks (#1353)
* Chore: update pre-commit refs

* run pre-commit
2024-08-24 06:46:36 -07:00
Alex Barron
b96e105244 Add grid_sample example to metal_kernel docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* add grid sample to docs

* zero_outputs -> init_value

* add missing header for linux
2024-08-23 18:24:16 -07:00
Awni Hannun
3b4d5484c7 Bump extension MLX version (#1350)
* Bump extension MLX version

* fix some docs nits
2024-08-23 12:38:34 -07:00
Alex Barron
684e11c664 patch (#1347) 2024-08-23 10:42:02 -07:00
Angelos Katharopoulos
b57a52813b Further reduction tuning (#1349)
* More reduction tuning
* Forgotten pdb
* Small column long row specialization
2024-08-23 10:35:25 -07:00
Alex Barron
da8deb2b62 fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-23 10:06:15 -07:00
Awni Hannun
98b6ce3460 Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-22 16:03:31 -07:00
Awni Hannun
f9e00efe31 fix nanobind and stub gen in circle (#1346) 2024-08-22 14:07:27 -07:00
Alex Barron
0fd2a1f4b0 Custom Metal Kernels from Python (#1325)
* start

* simple kernels working

* restructure

* inverse example working

* docs + fixes

* missing file

* fix imports

* address comments

* add docs + fix test

* Review comments + refactor to a single function

* update docs

* remove hashing

* fix contig bug in test

* back to a class

* trailing whitespace

* fix tests

* match c++ and python apis

* add link + make args kw_only
2024-08-22 13:46:29 -07:00
Awni Hannun
df3233454d 2d gather specialization (#1339) 2024-08-22 10:48:24 -07:00
Awni Hannun
82db84b899 bump nanobind + fix extension (#1344) 2024-08-21 16:05:07 -07:00
Awni Hannun
8ae751d3da fix io (#1343)
* fix io

* fix io

* comment
2024-08-21 13:14:46 -07:00
Awni Hannun
d40e76809f Fix rope (#1340)
* add test

* fix rope

* fix test
2024-08-20 17:37:52 -07:00
Awni Hannun
bb1b76d9dc RoPE with frequencies as optional input (#1337)
* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
2024-08-19 18:30:50 -07:00
Angelos Katharopoulos
9d26441224 Fix contiguity check (#1336)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-19 16:05:06 -07:00
Awni Hannun
f12f24a77c fix compiling with space in paths (#1332) 2024-08-15 16:39:24 -07:00
Awni Hannun
ae5b5cabfd Fix optimizer reloading from checkpoint (#1329)
* fix optimizer reloading from checkpoint

* comment
2024-08-15 07:33:23 -07:00
Awni Hannun
d0630ffe8c Read arrays from files faster (#1330)
* read faster

* faster write as well

* set default permission for linux

* comment
2024-08-14 20:09:56 -07:00
Alex Barron
99bb7d3a58 GPU mx.sign for complex64 (#1326) 2024-08-14 07:54:53 -07:00
Awni Hannun
63ae767232 fix transformer (#1327) 2024-08-13 16:04:26 -07:00
Awni Hannun
eaaea02010 Add isfinite (#1318)
* isfinite

* remove reduce test since fix is not complete
2024-08-13 14:49:28 -07:00
Bhargav Yagnik
a098bc92e0 Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output

- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- This resolves the issue #1321

* Update test cases in test_nn.py

- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout,

* updated dropout.py
2024-08-13 11:54:21 -07:00
Awni Hannun
1086dc4db0 patch (#1320) 2024-08-12 16:13:33 -07:00
Brian Keene
19fb69e2ed Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to the memory-efficient GPU shader at a prescribed sequence
length. Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
116 changed files with 4742 additions and 1901 deletions

View File

@@ -31,7 +31,7 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install nanobind==2.1.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@@ -44,6 +44,7 @@ jobs:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
@@ -76,7 +77,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install nanobind==2.1.0
pip install numpy
pip install torch
pip install tensorflow
@@ -90,6 +91,7 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
@@ -123,6 +125,12 @@ jobs:
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" CMAKE_ARGS="-DMLX_METAL_JIT=ON" pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
build_release:
parameters:
@@ -149,7 +157,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install nanobind==2.1.0
pip install --upgrade setuptools
pip install numpy
pip install twine
@@ -165,6 +173,7 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
@@ -213,7 +222,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install nanobind==2.1.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
@@ -222,6 +231,7 @@ jobs:
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \

View File

@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
rev: v18.1.8
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.16.2)
set(MLX_VERSION 0.17.2)
endif()
# --------------------- Processor tests -------------------------

View File

@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
world = mx.distributed.init()
if world.rank() == 0:
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
if world.rank() == 0:
print(f"{msec:.5f} msec")
def time_all_sum():
shape = (4096,)
x = mx.random.uniform(shape=shape)
mx.eval(x)
def sine(x):
for _ in range(20):
x = mx.sin(x)
return x
time_fn(sine, x)
def all_sum_plain(x):
for _ in range(20):
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_plain, x)
def all_sum_with_sine(x):
for _ in range(20):
x = mx.sin(x)
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_with_sine, x)
if __name__ == "__main__":
time_all_sum()

View File

@@ -0,0 +1,413 @@
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
source=source,
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
)
return outputs["out"]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
.. note::
We are only required to pass the body of the Metal kernel in ``source``.
The full function signature will be generated using:
* The keys and shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The keys and values of ``output_shapes`` and ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
These will be added as function arguments.
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp
template <typename T>
[[kernel]] void custom_kernel_myexp_float(
const device float16_t* inp [[buffer(0)]],
device float16_t* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
}
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
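For example, reusing the ``kernel`` and inputs from the simple example above, the following call would also print the generated source (a sketch for illustration):

.. code-block:: python

    outputs = kernel(
        inputs={"inp": a},
        template={"T": mx.float32},
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes={"out": a.shape},
        output_dtypes={"out": a.dtype},
        verbose=True,  # print the full generated Metal source for debugging
    )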
Using Shape/Strides
-------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
source=source
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
ensure_row_contiguous=False,
)
return outputs["out"]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Complex Example
---------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
source=source,
)
outputs = kernel(
inputs={"x": x, "grid": grid},
template={"T": x.dtype},
output_shapes={"out": out_shape},
output_dtypes={"out": x.dtype},
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs["out"]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
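A comparison like this can be reproduced with a small timing loop. The sketch below (the ``bench`` helper is illustrative, not part of MLX) warms up, evaluates each implementation a few times, and reports the average latency:

.. code-block:: python

    import time
    import mlx.core as mx

    def bench(fn, *args, iters=20):
        # Warm up, then time fully evaluated calls.
        for _ in range(3):
            mx.eval(fn(*args))
        tic = time.perf_counter()
        for _ in range(iters):
            mx.eval(fn(*args))
        return 1e3 * (time.perf_counter() - tic) / iters

    x = mx.random.normal(shape=(8, 1024, 1024, 64))
    grid = mx.random.uniform(low=-1, high=1, shape=(8, 256, 256, 2))
    mx.eval(x, grid)

    print(f"reference: {bench(grid_sample_ref, x, grid):.1f} ms")
    print(f"kernel:    {bench(grid_sample, x, grid):.1f} ms")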
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
* ``atomic_outputs=True``
Designate all of the kernel outputs as ``atomic`` in the function signature.
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backwards pass as follows:
.. code-block:: python
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T gix_mult = W / 2;
T giy_mult = H / 2;
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
source=source,
atomic_outputs=True,
)
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs={"x": x, "grid": grid, "cotangent": cotangent},
template={"T": x.dtype},
output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs["x_grad"], outputs["grid_grad"]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``
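Since ``grid_sample`` now has both a forward implementation and a custom vjp, it composes with MLX's transformations like any other function. Below is a minimal sketch (illustrative shapes and tolerances) that checks its gradients against the pure-MLX ``grid_sample_ref``:

.. code-block:: python

    x = mx.random.normal(shape=(2, 32, 32, 8))
    grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))

    def loss_custom(x, grid):
        return grid_sample(x, grid).sum()

    def loss_ref(x, grid):
        return grid_sample_ref(x, grid).sum()

    # Gradients from the custom vjp should match the autodiff of the reference.
    gx, ggrid = mx.grad(loss_custom, argnums=(0, 1))(x, grid)
    gx_ref, ggrid_ref = mx.grad(loss_ref, argnums=(0, 1))(x, grid)
    assert mx.allclose(gx, gx_ref, atol=1e-4)
    assert mx.allclose(ggrid, ggrid_ref, atol=1e-4)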

View File

@@ -85,3 +85,4 @@ are the CPU and GPU.
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels

View File

@@ -17,3 +17,6 @@ made available.
init
all_sum
all_gather
send
recv
recv_like

View File

@@ -12,3 +12,5 @@ Fast
layer_norm
rope
scaled_dot_product_attention
affine_quantize
metal_kernel

View File

@@ -44,6 +44,7 @@ Operations
convolve
conv1d
conv2d
conv3d
conv_general
cos
cosh

View File

@@ -2,7 +2,7 @@
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.9.0",
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
"mlx>=0.17.0",
"nanobind==2.1.0",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.16.2
nanobind==2.0
mlx>=0.17.0
nanobind==2.1.0

View File

@@ -13,7 +13,6 @@ if __name__ == "__main__":
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False,
python_requires=">=3.8",
)

View File

@@ -70,7 +70,6 @@ inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);

View File

@@ -2,6 +2,7 @@
#include <dlfcn.h>
#include <filesystem>
#include <fstream>
#include <list>
#include "mlx/backend/common/compiled.h"

View File

@@ -5,11 +5,9 @@
#include <utility>
#include "mlx/allocator.h"
#include "mlx/io/load.h"
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <const uint8_t scalar_size>
@@ -29,12 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
} // namespace
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
namespace mlx::core {
reader_->seek(offset_, std::ios_base::beg);
reader_->read(out.data<char>(), out.nbytes());
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) {
switch (out.itemsize()) {
@@ -51,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
}
}
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
} // namespace mlx::core

mlx/backend/common/load.h (new file)
View File

@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianess);
} // namespace mlx::core

View File

@@ -21,7 +21,7 @@ EOM
fi
CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

View File

@@ -373,6 +373,10 @@ struct Sign {
uint64_t operator()(uint64_t x) {
return x != 0;
}
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
};
struct Sin {

View File

@@ -87,6 +87,38 @@ struct OrReduce {
}
};
struct MaxReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y > x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y > x) ? *y : x;
}
};
};
struct MinReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y < x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y < x) ? *y : x;
}
};
};
template <typename InT>
void reduce_dispatch_out(
const array& in,
@@ -118,15 +150,13 @@ void reduce_dispatch_out(
break;
}
case Reduce::Max: {
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, op);
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
break;
}
case Reduce::Min: {
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, op);
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
break;
}
}

View File

@@ -49,7 +49,7 @@ struct ReductionPlan {
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Helper for the ndimensional strided loop
// Should this be in utils?

View File

@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
@@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
}
}
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
@@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
@@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break;
}
size *= x.shape(i);
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
}
if (size >= strides.back()) {
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}

View File

@@ -12,6 +12,7 @@ namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
General,
};
@@ -20,6 +21,12 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
topt = TernaryOpType::ScalarScalarScalar;
} else if (
(a.flags().row_contiguous && b.flags().row_contiguous &&
c.flags().row_contiguous) ||
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else {
topt = TernaryOpType::General;
}
@@ -33,11 +40,32 @@ void set_ternary_op_output_data(
array& out,
TernaryOpType topt,
bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
if (x.is_donatable() && x.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(x);
} else {
out.copy_shared_buffer(x);
}
return true;
}
return false;
};
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
}
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;

View File

@@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
std::vector<array>{std::forward<Arrays>(xs)...});
}
// The single array version of the above.
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::vector<int> collapsed_shape;
std::vector<size_t> collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<size_t>(shape[i]) >
std::numeric_limits<int>::max()) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
}
}
return std::make_tuple(collapsed_shape, collapsed_strides);
}
template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
@@ -115,8 +142,8 @@ inline auto check_contiguity(
bool is_col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {

View File

@@ -79,6 +79,7 @@ if (MLX_METAL_JIT)
kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
kernels/reduction/reduce_init.h
)
make_jit_source(
steel/gemm/gemm
@@ -131,6 +132,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

View File

@@ -0,0 +1,89 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
std::vector<array> copies;
for (auto& out : outputs) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (init_value_) {
array init = array(init_value_.value(), out.dtype());
copy_gpu(init, out, CopyType::Scalar, s);
copies.push_back(init);
}
}
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<const array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
lib = d.get_library(lib_name, metal::utils() + source_);
}
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index);
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder->setBytes(&ndim, sizeof(int), index);
index++;
}
}
}
for (array out : outputs) {
compute_encoder.set_output_array(out, index);
index++;
}
const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
} // namespace mlx::core::fast

View File

@@ -1,8 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
@@ -16,8 +14,6 @@
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
namespace {
@@ -38,20 +34,6 @@ constexpr auto get_metal_version() {
#endif
}
std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -311,12 +293,6 @@ void Device::register_library(
}
}
void Device::register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
// Search for cached metal lib
MTL::Library* mtl_lib;

View File

@@ -3,6 +3,8 @@
#pragma once
#include <Metal/Metal.hpp>
#include <dlfcn.h>
#include <filesystem>
#include <functional>
#include <mutex>
#include <string>
@@ -12,8 +14,26 @@
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
@@ -86,7 +106,13 @@ class Device {
const std::string& lib_name,
const std::string& lib_path);
void register_library(const std::string& lib_name);
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
MTL::Library* get_library(const std::string& name);

View File

@@ -0,0 +1,142 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::distributed {
void signal_and_wait(const array& in, const array& out, const Stream& s) {
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
if (in.event().valid()) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
}
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
}
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto task = [in = in,
out = out,
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
switch (reduce_type) {
case Sum:
distributed::detail::all_sum(
group, in.data_shared_ptr() == nullptr ? out : in, out);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream());
}
void AllGather::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto task = [in = in, out = out, group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream());
}
void Send::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
// Schedule an async send on the comm stream
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::send(group, in, dst);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a signal event for the input but not a wait since we don't need to
// wait on the output.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
if (in.event().valid()) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
}
}
void Recv::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Schedule an async recv on the comm stream
auto task = [out = out, group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a wait event as there is no input for the recv to encode a signal.
auto& s = stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
}
} // namespace mlx::core::distributed

View File

@@ -546,8 +546,8 @@ void fft_op(
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(x.shape(), strides);
flags.col_contiguous = is_row_contiguous;
flags.row_contiguous = is_col_contiguous;
flags.col_contiguous = is_col_contiguous;
flags.row_contiguous = is_row_contiguous;
flags.contiguous = data_size == x_copy.size();
x_copy.set_data(

View File

@@ -95,11 +95,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
slice_size *= s;
}
// Launch 2D grid of threads: indices x slice
size_t dim0 = out.size() / slice_size;
size_t dim1 = slice_size;
auto group_dims = get_block_dims(dim0, dim1, 1);
MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
// Launch 3D grid of threads
// First two dimensions for the indices, the last one for the slice
size_t dim0 = 1;
size_t dim1 = 1;
if (nidx) {
if (inputs[1].ndim() >= 1) {
dim0 = inputs[1].shape(0);
}
if (inputs[1].ndim() >= 2) {
dim1 = inputs[1].size() / dim0;
}
}
size_t dim2 = slice_size;
auto group_dims = get_block_dims(dim0, dim1, dim2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;

View File

@@ -13,8 +13,8 @@ constexpr std::string_view gather_kernels = R"(
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
{4}
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {{
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};

View File

@@ -1,168 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
device {1}* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {{
out[tid] = {2}<{1}>::init;
}}
)";
constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]);
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -6,7 +6,6 @@
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
@@ -323,12 +322,13 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< fmt::format(
reduce_init_kernels,
kernel_name,
get_type_string(out.dtype()),
op_name(out));
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -347,14 +347,22 @@ MTL::ComputePipelineState* get_reduce_kernel(
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
<< fmt::format(
non_atomic ? reduce_non_atomic_kernels
: reduce_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_type);
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::vector<std::pair<std::string, std::string>> reduce_kernels = {
{"all_reduce", "allReduce"},
{"col_reduce_small", "colReduceSmall"},
{"col_reduce_looped", "colReduceLooped"},
{"row_reduce_small", "rowReduceSmall"},
{"row_reduce_looped", "rowReduceLooped"},
{"row_reduce_simple", "rowReduceSimple"}};
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
for (auto [func, name] : reduce_kernels) {
kernel_source << get_template_definition(
name + "_" + lib_name, func, in_type, out_type, op);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);

View File

@@ -37,13 +37,13 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@@ -51,13 +51,15 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
METAL_FUNC void mlx_atomic_fetch_or_explicit(
device mlx_atomic<T>* object,
T val,
size_t offset) {
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@@ -65,7 +67,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@@ -73,7 +75,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@@ -81,7 +83,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}
@@ -89,7 +91,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
T expected = mlx_atomic_load_explicit(object, offset);
while (!mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val * expected, offset)) {
@@ -101,7 +103,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread T* expected,
T val,
uint offset) {
size_t offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
@@ -115,7 +117,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
size_t offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val < expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -130,7 +132,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
size_t offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val > expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -157,7 +159,7 @@ union uint_or_packed {
template <typename T, typename Op>
struct mlx_atomic_update_helper {
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
Op op;
init.val[elem_offset] = op(update, init.val[elem_offset]);
return init.bits;
@@ -168,9 +170,9 @@ template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
device mlx_atomic<T>* object,
T update,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
mlx_atomic_update_helper<T, Op> helper;
uint_or_packed<T> expected;
@@ -251,9 +253,9 @@ struct __Min {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
uint pack_offset = offset / sizeof(T);
uint elem_offset = offset % sizeof(T);
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
size_t pack_offset = offset / sizeof(T);
size_t elem_offset = offset % sizeof(T);
uint_or_packed<T> packed_val;
packed_val.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
@@ -262,7 +264,7 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}
@@ -270,9 +272,9 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = __UINT32_MAX__;
identity.val[elem_offset] = val;
@@ -282,10 +284,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit(
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
METAL_FUNC void mlx_atomic_fetch_or_explicit(
device mlx_atomic<T>* object,
T val,
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = 0;
identity.val[elem_offset] = val;
@@ -298,7 +302,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}
@@ -306,7 +310,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}
@@ -314,7 +318,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}
@@ -322,7 +326,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}
@@ -331,7 +335,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread uint* expected,
uint val,
uint offset) {
size_t offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,

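The non-native path above packs several small values into one 32-bit atomic word and updates a single lane with a compare-exchange loop. A rough host-side analogue of that pattern, using std::atomic<uint32_t> in place of the Metal atomics (names and the fixed packing of four 8-bit lanes are illustrative assumptions, not the MLX API):

#include <atomic>
#include <cstdint>
#include <cstdio>

// Four uint8_t lanes packed into one 32-bit word, mirroring uint_or_packed.
union uint_or_packed {
  uint32_t bits;
  uint8_t val[4];
};

// CAS loop that adds `update` into one lane of a packed atomic word,
// analogous in spirit to mlx_atomic_update_and_store with an Add op.
void packed_fetch_add(std::atomic<uint32_t>& object, uint8_t update, size_t offset) {
  size_t elem_offset = offset % 4;  // lane inside the word (packing size 4)
  uint_or_packed expected;
  expected.bits = object.load(std::memory_order_relaxed);
  while (true) {
    uint_or_packed desired = expected;
    desired.val[elem_offset] = uint8_t(desired.val[elem_offset] + update);
    if (object.compare_exchange_weak(
            expected.bits, desired.bits, std::memory_order_relaxed)) {
      break;
    }
    // On failure expected.bits holds the latest word; retry with it.
  }
}

int main() {
  std::atomic<uint32_t> word{0};
  for (int i = 0; i < 10; i++) {
    packed_fetch_add(word, 1, /*offset=*/2);  // always hit lane 2
  }
  uint_or_packed result;
  result.bits = word.load();
  std::printf("lane 2 = %u\n", result.val[2]);  // prints 10
}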
View File

@@ -23,6 +23,8 @@ struct complex64_t {
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
constexpr complex64_t() : real(0), imag(0) {};
constexpr complex64_t() threadgroup : real(0), imag(0) {};
// Conversions to complex64_t
template <

View File

@@ -9,7 +9,8 @@
#endif
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

View File

@@ -14,32 +14,36 @@ METAL_FUNC void gather_impl(
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
idx_loc = index.x * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
size_t out_idx = index.z;
if (IDX_NDIM == 1) {
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) {
out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
}
out[out_idx] = src[src_offset + src_idx];
}
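The new indexing flattens the 3-D launch back into a linear output index, with a different formula per IDX_NDIM case. A small check of that arithmetic on the host (the uint3 struct and sizes below are assumptions for illustration, not MLX code):

#include <cstddef>
#include <cstdio>

struct uint3 { unsigned x, y, z; };

// Mirrors the out_idx computation in gather_impl for the three IDX_NDIM cases.
size_t gather_out_index(uint3 index, uint3 grid_dim, int idx_ndim) {
  size_t out_idx = index.z;
  if (idx_ndim == 1) {
    out_idx += size_t(grid_dim.z) * index.x;
  } else if (idx_ndim >= 2) {
    out_idx += grid_dim.z * (index.x * size_t(grid_dim.y) + index.y);
  }
  return out_idx;
}

int main() {
  uint3 grid_dim{4, 3, 5};  // x: index batch, y: leading index dims, z: slice
  uint3 index{2, 1, 3};
  std::printf("%zu\n", gather_out_index(index, grid_dim, 2));  // 5*(2*3+1)+3 = 38
}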

View File

@@ -1,4 +1,5 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"

View File

@@ -8,7 +8,6 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \
@@ -28,7 +27,8 @@
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op)
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
@@ -83,9 +83,9 @@
op)
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i_reduce_" #name)]] [[kernel]] void \
init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
@@ -97,40 +97,15 @@ instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
instantiate_kernel("allReduce_" #name, \
all_reduce, \
itype, otype, op)
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
@@ -138,153 +113,72 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("colReduceSmall_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim)
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("colReduceLooped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 1) \
instantiate_col_reduce_small(name, itype, otype, op, 2) \
instantiate_col_reduce_small(name, itype, otype, op, 3) \
instantiate_col_reduce_small(name, itype, otype, op, 4) \
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
instantiate_col_reduce_looped(name, itype, otype, op, 4)
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_general(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("rowReduceSmall_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("rowReduceLooped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_kernel("rowReduceSimple_" #name, \
row_reduce_simple, \
itype, otype, op)
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)

View File

@@ -5,6 +5,20 @@
#include <metal_atomic>
#include <metal_simdgroup>
#define DEFINE_SIMD_REDUCE() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_reduce(T val) { \
return simd_reduce_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_reduce(T val) { \
for (short i = simd_size / 2; i > 0; i /= 2) { \
val = operator()(val, simd_shuffle_down(val, i)); \
} \
return val; \
}
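For 8-byte types the macro falls back to a shuffle-down tree, since the built-in simd reductions only cover smaller types. A serial emulation of that tree over one 32-lane simdgroup (plain C++, with simd_shuffle_down modeled as an array lookup and out-of-range lanes treated as the identity):

#include <cstdint>
#include <cstdio>

constexpr int simd_size = 32;

// Emulates the loop in DEFINE_SIMD_REDUCE's sizeof(T) == 8 branch: each round
// every lane combines its value with the lane `offset` positions above it,
// halving `offset`, so lane 0 ends up holding the full reduction.
int64_t simd_tree_sum(int64_t lanes[simd_size]) {
  for (int offset = simd_size / 2; offset > 0; offset /= 2) {
    for (int lane = 0; lane < simd_size; lane++) {
      // simd_shuffle_down(val, offset) reads the value of lane + offset; here
      // lanes past the end contribute the identity (0 for a sum).
      int64_t shuffled = (lane + offset < simd_size) ? lanes[lane + offset] : 0;
      lanes[lane] += shuffled;
    }
  }
  return lanes[0];
}

int main() {
  int64_t lanes[simd_size];
  for (int i = 0; i < simd_size; i++) lanes[i] = i + 1;  // 1..32
  std::printf("%lld\n", (long long)simd_tree_sum(lanes));  // 528 = 32*33/2
}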
static constant constexpr const uint8_t simd_size = 32;
union bool4_or_uint {
@@ -14,14 +28,16 @@ union bool4_or_uint {
struct None {
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_store_explicit(out, val, offset);
}
};
template <typename U = bool>
struct And {
bool simd_reduce(bool val) {
DEFINE_SIMD_REDUCE()
bool simd_reduce_impl(bool val) {
return simd_all(val);
}
@@ -31,7 +47,7 @@ struct And {
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
size_t offset = 0) {
if (!val) {
bool4_or_uint update;
update.b = {true, true, true, true};
@@ -40,7 +56,8 @@ struct And {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
if (!val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -59,7 +76,9 @@ struct And {
template <typename U = bool>
struct Or {
bool simd_reduce(bool val) {
DEFINE_SIMD_REDUCE()
bool simd_reduce_impl(bool val) {
return simd_any(val);
}
@@ -68,8 +87,8 @@ struct Or {
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
uint elem_idx,
uint offset = 0) {
int elem_idx,
size_t offset = 0) {
if (val) {
bool4_or_uint update;
update.b = {false, false, false, false};
@@ -78,7 +97,8 @@ struct Or {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
if (val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -97,15 +117,17 @@ struct Or {
template <typename U>
struct Sum {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_sum(val);
}
static constexpr constant U init = U(0);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_add_explicit(out, val, offset);
}
@@ -117,15 +139,17 @@ struct Sum {
template <typename U>
struct Prod {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_product(val);
}
static constexpr constant U init = U(1);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_mul_explicit(out, val, offset);
}
@@ -137,15 +161,17 @@ struct Prod {
template <typename U>
struct Min {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_min(val);
}
static constexpr constant U init = Limits<U>::max;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_min_explicit(out, val, offset);
}
@@ -157,15 +183,17 @@ struct Min {
template <typename U>
struct Max {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_max(val);
}
static constexpr constant U init = Limits<U>::min;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_max_explicit(out, val, offset);
}

View File

@@ -1,135 +1,61 @@
// Copyright © 2023-2024 Apple Inc.
///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
const device T* in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for (int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for (int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for (int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
///////////////////////////////////////////////////////////////////////////////
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
device U* out [[buffer(1)]],
const constant size_t& in_size [[buffer(2)]],
const constant size_t& row_size [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
threadgroup U shared_vals[simd_size];
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total = Op::init;
int64_t start_idx = gid.y * row_size;
int64_t actual_row =
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
int64_t blocks = actual_row / (lsize.x * N_READS);
int extra = actual_row - blocks * (lsize.x * N_READS);
extra -= lid.x * N_READS;
start_idx += lid.x * N_READS;
in += start_idx;
if (extra >= N_READS) {
blocks++;
extra = 0;
}
for (int64_t b = 0; b < blocks; b++) {
for (int i = 0; i < N_READS; i++) {
total = op(static_cast<U>(in[i]), total);
}
in += lsize.x * N_READS;
}
if (extra > 0) {
for (int i = 0; i < extra; i++) {
total = op(static_cast<U>(in[i]), total);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
total = op.simd_reduce(total);
if (simd_per_group > 1) {
if (simd_lane_id == 0) {
shared_vals[simd_group_id] = total;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init;
total = op.simd_reduce(total);
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
// Reduction across threadgroups
if (lid == 0) {
op.atomic_update(out, total_val);
}
}
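The rewritten all_reduce splits each row among the threads of a threadgroup in chunks of N_READS, with the blocks/extra split covering the ragged tail exactly once. A host-side simulation of that addressing for a single row (hypothetical sizes, plain C++), checking that every element is visited exactly once:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int N_READS = 4;
  const int lsize_x = 8;            // threads in the group along x
  const int64_t actual_row = 103;   // row length, deliberately not a multiple

  std::vector<int> visits(actual_row, 0);

  for (int lid_x = 0; lid_x < lsize_x; lid_x++) {
    // Same split as the kernel: full blocks for everyone, then a ragged tail.
    int64_t blocks = actual_row / (lsize_x * N_READS);
    int extra = int(actual_row - blocks * (lsize_x * N_READS));
    extra -= lid_x * N_READS;
    int64_t idx = lid_x * N_READS;  // this thread's first element
    if (extra >= N_READS) {         // this thread still owns one full chunk
      blocks++;
      extra = 0;
    }
    for (int64_t b = 0; b < blocks; b++) {
      for (int i = 0; i < N_READS; i++) visits[idx + i]++;
      idx += lsize_x * N_READS;     // stride past the other threads' chunks
    }
    for (int i = 0; i < extra; i++) visits[idx + i]++;  // ragged tail
  }

  bool ok = true;
  for (int v : visits) ok &= (v == 1);
  std::printf("every element visited exactly once: %s\n", ok ? "yes" : "no");
}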
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64
// types)
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
if (lid.x == 0) {
out[gid.y] = total;
}
}

View File

@@ -1,165 +1,337 @@
// Copyright © 2023-2024 Apple Inc.
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]) {
// Appease the compiler
(void)out_size;
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[thread_position_in_grid]],
uint3 tsize [[threads_per_grid]]) {
Op op;
U total_val = Op::init;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
auto out_idx = tid;
// Case 1: Small row small column
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
U totals[31];
for (int i = 0; i < 31; i++) {
totals[i] = Op::init;
}
in += elem_to_loc(
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
short stride = reduction_stride;
short size = reduction_size;
short blocks = stride / N_READS;
short extra = stride - blocks * N_READS;
for (uint i = 0; i < non_col_reductions; i++) {
size_t in_idx =
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
U val = static_cast<U>(in[in_idx]);
total_val = op(total_val, val);
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
for (short i = 0; i < size; i++) {
for (short j = 0; j < blocks; j++) {
for (short k = 0; k < N_READS; k++) {
totals[j * N_READS + k] =
op(totals[j * N_READS + k],
static_cast<U>(row[i * stride + j * N_READS + k]));
}
}
for (short k = 0; k < extra; k++) {
totals[blocks * N_READS + k] =
op(totals[blocks * N_READS + k],
static_cast<U>(row[i * stride + blocks * N_READS + k]));
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride;
for (short j = 0; j < stride; j++) {
out[j] = totals[j];
}
}
out[out_idx] = total_val;
}
// Case 2: Long row small column
else if (reduction_size * non_col_reductions < 32) {
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
short size = reduction_size;
size_t offset = size_t(tid.x) * N_READS;
bool safe = offset + N_READS <= reduction_stride;
short extra = reduction_stride - offset;
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U _contiguous_strided_reduce(
const device T* in,
threadgroup U* local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
U total_val = Op::init;
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
uint offset = base_offset + r;
total_val =
op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
U val = Op::init;
if (lid.y == 0) {
// Perform reduction across columns in thread group
for (uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
if (safe) {
for (short i = 0; i < size; i++) {
for (short j = 0; j < N_READS; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
} else {
for (short i = 0; i < size; i++) {
for (short j = 0; j < extra; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride + offset;
if (safe) {
for (short i = 0; i < N_READS; i++) {
out[i] = totals[i];
}
} else {
for (short i = 0; i < extra; i++) {
out[i] = totals[i];
}
}
}
return val;
}
// Case 3: Long row medium column
else {
threadgroup U shared_vals[1024];
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
///////////////////////////////////////////////////////////////////////////////
// Column reduce kernel
///////////////////////////////////////////////////////////////////////////////
short stride = reduction_stride;
short lid = simd_group_id * simd_size + simd_lane_id;
short2 tile((stride + N_READS - 1) / N_READS, 32);
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
short sm_stride = tile.x * N_READS;
bool safe = offset.x + N_READS <= stride;
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
Op op;
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Read cooperatively and contiguously and aggregate the partial results.
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
if (safe) {
for (int i = 0; i < N_READS; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(simd_size, reduce_shape, reduce_strides);
}
// Each thread holds N_READS partial results, but they are not arranged so
// that a simdgroup can reduce across them directly, so we write them to
// shared memory and read them back in an order the simdgroup can reduce.
for (int i = 0; i < N_READS; i++) {
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_READS; i++) {
totals[i] = op.simd_reduce(
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
}
// Write the output.
if (simd_lane_id == 0) {
short column = simd_group_id * N_READS;
out += out_idx * reduction_stride + column;
if (column + N_READS <= stride) {
for (int i = 0; i < N_READS; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; column + i < stride; i++) {
out[i] = totals[i];
}
}
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
/**
* Our approach is the following simple looped approach:
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
* 2. Load a tile BM, BN in registers and accumulate in the running totals
* 3. Move ahead by BM steps until the column axis and the non column
* reductions are exhausted.
* 4. If BM == 32 then transpose in SM and simd reduce the running totals.
* Otherwise write in shared memory and BN threads accumulate the running
* totals with a loop.
* 5. Write them to the output
*/
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int BM = 8,
int BN = 128>
[[kernel]] void col_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 4;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
threadgroup U shared_vals[BN * BM];
U totals[n_reads];
looped_elem_to_loc<NDIMS> loop;
const device T* row;
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(BM, reduce_shape, reduce_strides);
}
// We can use a simd reduction to accumulate across BM so each thread writes
// the partial output to SM and then each simdgroup does BN / n_simdgroups
// accumulations.
if (BM == 32) {
constexpr int n_outputs = BN / n_simdgroups;
static_assert(
BM != 32 || n_outputs == n_reads,
"The tile should be selected such that n_outputs == n_reads");
for (int i = 0; i < n_reads; i++) {
shared_vals[offset.y * BN + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
for (int i = 0; i < n_outputs; i++) {
totals[i] =
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
}
// Write the output.
if (simd_lane_id == 0) {
size_t out_column = BN * gid.x + out_offset.x;
out += out_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; out_column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
// Each thread holds n_reads partial results. We write them all out to shared
// memory and threads with offset.y == 0 aggregate the columns and write the
// outputs.
else {
short x_block = offset.x / n_reads;
for (int i = 0; i < n_reads; i++) {
shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (offset.y == 0) {
for (int i = 0; i < n_reads; i++) {
for (int j = 1; j < BM; j++) {
totals[i] =
op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
}
}
}
// Write the output.
if (offset.y == 0) {
out += out_idx * reduction_stride + column;
if (safe) {
for (int i = 0; i < n_reads; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
}
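In col_reduce_looped each thread's lane id is turned into a (column, row) offset inside the BM x BN tile. A quick check of that arithmetic for the two instantiated tiles (8 x 128 and 32 x 32), assuming n_simdgroups = 4 and simd_size = 32 as in the kernel:

#include <cstdio>

int main() {
  const int simd_size = 32;
  const int n_simdgroups = 4;
  const int tgp_size = n_simdgroups * simd_size;  // 128 threads per group

  const int tiles[2][2] = {{8, 128}, {32, 32}};   // {BM, BN} as instantiated
  for (auto& t : tiles) {
    int BM = t[0], BN = t[1];
    int n_reads = (BM * BN) / tgp_size;   // columns each thread accumulates
    int n_read_blocks = BN / n_reads;     // threads spanning one row of BN
    std::printf("BM=%2d BN=%3d -> n_reads=%d n_read_blocks=%d\n",
                BM, BN, n_reads, n_read_blocks);

    // First few threads: offset.x is the starting column, offset.y the row
    // within the tile, matching short2 offset(...) in the kernel.
    for (int lid = 0; lid < 3; lid++) {
      int off_x = (lid % n_read_blocks) * n_reads;
      int off_y = lid / n_read_blocks;
      std::printf("  lid=%d -> offset=(%d, %d)\n", lid, off_x, off_y);
    }
  }
}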

View File

@@ -1,287 +1,366 @@
// Copyright © 2023-2024 Apple Inc.
///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////
// Row reduction utilities
// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
// - `threadgroup_reduce` collaborative reduction in the threadgroup such that
// lid.x == 0 holds the reduced value
// - `thread_reduce` simple loop and reduce the row
// Each thread reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]) {
/**
* The thread group collaboratively reduces across the rows with bounds
* checking. In the end each thread holds a part of the reduction.
*/
template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
thread U totals[N_WRITES],
const device T* inputs[N_WRITES],
int blocks,
int extra,
uint lsize_x,
uint lid_x) {
Op op;
uint out_idx = lid;
if (out_idx >= out_size) {
return;
// Set up the accumulator registers
for (int i = 0; i < N_WRITES; i++) {
totals[i] = Op::init;
}
U total_val = Op::init;
// Loop over the reduction size within thread group
for (int i = 0; i < blocks; i++) {
for (int j = 0; j < N_WRITES; j++) {
for (int i = 0; i < N_READS; i++) {
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
}
for (short r = 0; r < short(non_row_reductions); r++) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
inputs[j] += lsize_x * N_READS;
}
}
out[out_idx] = total_val;
}
// Each simdgroup reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_med(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
uint out_idx = simd_per_group * tid + simd_group_id;
if (out_idx >= out_size) {
return;
}
U total_val = Op::init;
if (short(non_row_reductions) == 1) {
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
else if (short(non_row_reductions) >= 32) {
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
// Separate case for the last set as we close the reduction size
int index = lid_x * N_READS;
if (index + N_READS <= extra) {
for (int j = 0; j < N_WRITES; j++) {
for (int i = 0; i < N_READS; i++) {
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
}
}
}
else {
const short n_reductions =
short(reduction_size) * short(non_row_reductions);
const short reductions_per_thread =
(n_reductions + simd_size - 1) / simd_size;
const short r_st = simd_lane_id / reductions_per_thread;
const short r_ed = short(non_row_reductions);
const short r_jump = simd_size / reductions_per_thread;
const short i_st = simd_lane_id % reductions_per_thread;
const short i_ed = short(reduction_size);
const short i_jump = reductions_per_thread;
if (r_st < r_jump) {
for (short r = r_st; r < r_ed; r += r_jump) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T* in_row = in + in_idx;
for (short i = i_st; i < i_ed; i += i_jump) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
} else {
for (int j = 0; j < N_WRITES; j++) {
for (int i = 0; index + i < extra; i++) {
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
}
}
}
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
out[out_idx] = total_val;
}
}
///////////////////////////////////////////////////////////////////////////////
// Large row reductions
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_row_reduce(
/**
* Consecutive rows in a contiguous array.
*/
template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
thread U totals[N_WRITES],
const device T* in,
const constant size_t& reduction_size,
const constant size_t& out_size,
int blocks,
int extra,
uint lsize_x,
uint lid_x) {
// Set up the input pointers
const device T* inputs[N_WRITES];
inputs[0] = in + lid_x * N_READS;
for (int i = 1; i < N_WRITES; i++) {
inputs[i] = inputs[i - 1] + reduction_size;
}
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
totals, inputs, blocks, extra, lsize_x, lid_x);
}
/**
* Consecutive rows in an arbitrarily ordered array.
*/
template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
thread U totals[N_WRITES],
const device T* in,
const size_t row_idx,
int blocks,
int extra,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for (int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
uint lid_x) {
// Set up the input pointers
const device T* inputs[N_WRITES];
in += lid_x * N_READS;
for (int i = 0; i < N_READS; i++) {
inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if (reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for (int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
totals, inputs, blocks, extra, lsize_x, lid_x);
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
/**
* Reduce within the threadgroup.
*/
template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
thread U totals[N_WRITES],
threadgroup U* shared_vals,
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
in,
reduction_size,
out_size,
shape,
strides,
ndim,
lsize.x,
lid.x,
tid.xy);
total_val = op.simd_reduce(total_val);
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
// Simdgroup first
for (int i = 0; i < N_WRITES; i++) {
totals[i] = op.simd_reduce(totals[i]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if multiple simd groups
if (reduction_size > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid.x == 0) {
op.atomic_update(out, total_val, tid.x);
// Across simdgroups
if (simd_per_group > 1) {
if (simd_lane_id == 0) {
for (int i = 0; i < N_WRITES; i++) {
shared_vals[simd_group_id * N_WRITES + i] = totals[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
U values[N_WRITES];
for (int i = 0; i < N_WRITES; i++) {
values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
: op.init;
}
for (int i = 0; i < N_WRITES; i++) {
totals[i] = op.simd_reduce(values[i]);
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
Op op;
for (int i = 0; i < blocks; i++) {
U vals[N_READS];
for (int j = 0; j < N_READS; j++) {
vals[j] = row[j];
}
for (int j = 0; j < N_READS; j++) {
total = op(vals[j], total);
}
row += N_READS;
}
for (int i = 0; i < extra; i++) {
total = op(*row++, total);
}
}
// Reduction kernels
// - `row_reduce_small` depending on the non-row reductions and row size it
// either just loops over everything or a simd collaboratively reduces the
// non-row reductions. In the first case one thread is responsible for one
// output; in the second case one simd is responsible for one output.
// - `row_reduce_simple` simple contiguous row reduction
// - `row_reduce_looped` simply loop and reduce each row for each non-row
// reduction. One threadgroup is responsible for one output.
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 tid [[thread_position_in_grid]],
uint3 tsize [[threads_per_grid]]) {
Op op;
U total_val = Op::init;
looped_elem_to_loc<NDIMS> loop;
// Precompute some row reduction numbers
const device T* row;
int blocks = row_size / N_READS;
int extra = row_size % N_READS;
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread.
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(reduce_shape, reduce_strides);
}
out[out_idx] = total_val;
} else {
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
// thread reduces every 32nd row and then a simple simd reduce.
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim);
loop.next(simd_lane_id, reduce_shape, reduce_strides);
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(simd_size, reduce_shape, reduce_strides);
}
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
out[out_idx] = total_val;
}
}
}
template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
threadgroup U shared_vals[simd_size * N_WRITES];
U totals[N_WRITES];
// Move to the row
size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
if (out_idx + N_WRITES > out_size) {
out_idx = out_size - N_WRITES;
}
in += out_idx * reduction_size;
out += out_idx;
// Each thread reduces across the row
int blocks = reduction_size / (lsize.x * N_READS);
int extra = reduction_size - blocks * (lsize.x * N_READS);
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
// Reduce across the threadgroup
threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
// Write the row reduce output with the first thread in the threadgroup
if (lid.x == 0) {
for (int i = 0; i < N_WRITES; i++) {
out[i] = totals[i];
}
}
}
template <
typename T,
typename U,
typename Op,
int NDIMS = 0,
int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U shared_vals[simd_size];
U total = Op::init;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
// needs a small refactor.
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
int blocks = row_size / (lsize.x * N_READS);
int extra = row_size - blocks * (lsize.x * N_READS);
for (size_t i = 0; i < non_row_reductions; i++) {
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
// Each thread reduces across the row
U row_total;
per_thread_row_reduce<T, U, Op, N_READS, 1>(
&row_total, &row, blocks, extra, lsize.x, lid.x);
// Aggregate across rows
total = op(total, row_total);
loop.next(reduce_shape, reduce_strides);
}
// Reduce across the threadgroup
threadgroup_reduce<T, U, Op, N_READS, 1>(
&total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
// Write the output
if (lid.x == 0) {
out[out_idx] = total;
}
}

View File

@@ -4,44 +4,36 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
void rope_single_impl(
const device T* in,
device T* out,
constant const int& offset,
constant const float& base,
const float inv_freq,
constant const float& scale,
constant const size_t& stride,
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
// Figure out L and d.
uint2 pos,
uint2 grid) {
float L = scale * static_cast<float>(offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
uint index_1, index_2;
if (traditional) {
out_index_1 = 2 * pos.x + pos.y * stride;
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x + pos.y * stride;
in_index_2 = in_index_1 + 1;
index_1 = 2 * pos.x + pos.y * stride;
index_2 = index_1 + 1;
} else {
out_index_1 = pos.x + pos.y * stride;
out_index_2 = out_index_1 + grid.x;
in_index_1 = pos.x + pos.y * stride;
in_index_2 = in_index_1 + grid.x;
index_1 = pos.x + pos.y * stride;
index_2 = index_1 + grid.x;
}
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float x1 = static_cast<float>(in[index_1]);
float x2 = static_cast<float>(in[index_2]);
float rx1;
float rx2;
if (forward) {
@@ -51,28 +43,58 @@ template <typename T, bool traditional, bool forward>
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
out[index_1] = static_cast<T>(rx1);
out[index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& base,
constant const float& scale,
constant const size_t& stride,
constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl(
const device T* in,
device T* out,
constant const int& offset,
const float inv_freq,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Figure out L and d.
uint3 pos,
uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
@@ -116,37 +138,115 @@ template <typename T, bool traditional, bool forward, int N = 4>
}
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_impl<T, traditional, forward, N>(
in,
out,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_impl<T, traditional, forward, N>(
in,
out,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
grid);
}
// clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t& stride, \
uint2 pos [[thread_position_in_grid]], \
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)

View File

@@ -18,7 +18,7 @@ METAL_FUNC void scatter_1d_index_impl(
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
size_t out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];

View File

@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_ -> wt_strides[0]),
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),

View File

@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_ -> wt_strides[0]),
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),

View File

@@ -218,7 +218,7 @@ struct Conv2DWeightBlockLoaderGeneral {
const short base_ww_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_ -> wt_strides[0]),
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),

View File

@@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t)
instantiate_unary_all(Cosh, complex64, complex64_t)
instantiate_unary_all(Exp, complex64, complex64_t)
instantiate_unary_all(Negative, complex64, complex64_t)
instantiate_unary_all(Sign, complex64, complex64_t)
instantiate_unary_all(Sin, complex64, complex64_t)
instantiate_unary_all(Sinh, complex64, complex64_t)
instantiate_unary_all(Tan, complex64, complex64_t)

View File

@@ -308,6 +308,14 @@ struct Sign {
uint32_t operator()(uint32_t x) {
return x != 0;
};
template <>
complex64_t operator()(complex64_t x) {
if (x == complex64_t(0)) {
return x;
}
return x /
(complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);
};
};
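Written out, the complex specialization above computes the usual complex sign:
\operatorname{sign}(z) =
\begin{cases}
0 & \text{if } z = 0, \\
\dfrac{z}{\sqrt{\operatorname{Re}(z)^2 + \operatorname{Im}(z)^2}} = \dfrac{z}{\lvert z \rvert} & \text{otherwise.}
\end{cases}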
struct Sin {

View File

@@ -64,6 +64,16 @@ struct Limits<bool> {
static constexpr constant bool min = false;
};
template <>
struct Limits<complex64_t> {
static constexpr constant complex64_t max = complex64_t(
metal::numeric_limits<float>::infinity(),
metal::numeric_limits<float>::infinity());
static constexpr constant complex64_t min = complex64_t(
-metal::numeric_limits<float>::infinity(),
-metal::numeric_limits<float>::infinity());
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
@@ -101,6 +111,34 @@ METAL_FUNC stride_t elem_to_loc(
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
device const int* shape,
device const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
constant const int* shape,
constant const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
@@ -288,12 +326,87 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
template <int dim, typename offset_t = size_t>
struct looped_elem_to_loc {
looped_elem_to_loc<dim - 1, offset_t> inner_looper;
offset_t offset{0};
int index{0};
void next(const constant int* shape, const constant size_t* strides) {
index++;
offset += strides[dim - 1];
if (index >= shape[dim - 1]) {
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
}
}
void next(int n, const constant int* shape, const constant size_t* strides) {
index += n;
offset += n * strides[dim - 1];
if (index >= shape[dim - 1]) {
int extra = index - shape[dim - 1];
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
if (extra > 0) {
next(extra, shape, strides);
}
}
}
offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset;
}
};
template <typename offset_t>
struct looped_elem_to_loc<1, offset_t> {
offset_t offset{0};
void next(const constant int*, const constant size_t* strides) {
offset += strides[0];
}
void next(int n, const constant int*, const constant size_t* strides) {
offset += n * strides[0];
}
offset_t
location(offset_t, const constant int*, const constant size_t*, int) {
return offset;
}
};
template <typename offset_t>
struct looped_elem_to_loc<0, offset_t> {
void next(const constant int*, const constant size_t*) {}
void next(int, const constant int*, const constant size_t*) {}
offset_t location(
offset_t idx,
const constant int* shape,
const constant size_t* strides,
int ndim) {
return elem_to_loc(idx, shape, strides, ndim);
}
};
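As a concrete check of this index arithmetic, take a hypothetical 2x3 view with `shape = {2, 3}` and (transposed) `strides = {1, 2}`: flat element 4 has row-major coordinates (1, 1), so its storage offset is 1*1 + 1*2 = 3. A standalone C++ copy of the loop confirms it:
#include <cassert>
#include <cstddef>

// Standalone copy of the elem_to_loc index arithmetic for illustration.
size_t elem_to_loc_ref(
    size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  int shape[2] = {2, 3};
  size_t strides[2] = {1, 2}; // a transposed 2x3 view
  assert(elem_to_loc_ref(4, shape, strides, 2) == 3);
  return 0;
}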
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/** Compute ceil((float)N/(float)M) */
inline size_t ceildiv(size_t N, size_t M) {
template <typename T, typename U>
inline T ceildiv(T N, U M) {
return (N + M - 1) / M;
}
@@ -339,3 +452,8 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
return complex64_t(
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
}

View File

@@ -14,9 +14,9 @@ SRC_NAME=$(basename -- "${SRC_FILE}")
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p $OUTPUT_DIR
mkdir -p "$OUTPUT_DIR"
CONTENT=$($CC -I $SRC_DIR -DMLX_METAL_JIT -E -P $INPUT_FILE $CFLAGS 2>/dev/null)
CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal {

View File

@@ -47,8 +47,6 @@ std::function<void()> make_task(array arr, bool signal) {
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
// TODO, consider committing the buffer and encoding a wait in the new
// buffer rather than on the task thread
input.event().wait();
}
}

View File

@@ -135,8 +135,8 @@ void RMSNormVJP::eval_gpu(
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
// Allocate the gradient accumulator gw and a temporary to store the
// gradients before they are accumulated.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
@@ -146,11 +146,7 @@ void RMSNormVJP::eval_gpu(
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
{
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
copies.push_back(std::move(zero));
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
const int simd_size = 32;
const int n_reads = RMS_N_READS;
@@ -330,8 +326,8 @@ void LayerNormVJP::eval_gpu(
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and initialize the
// gradient accumulator to 0.
// Allocate a temporary to store the gradients for w and allocate the output
// gradient accumulators.
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
bool g_in_gw = false;
if (!g_in_gx && g.is_donatable()) {
@@ -341,12 +337,8 @@ void LayerNormVJP::eval_gpu(
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
}
copies.push_back(gw_temp);
{
array zero(0, gw.dtype());
copy_gpu(zero, gw, CopyType::Scalar, s);
copy_gpu(zero, gb, CopyType::Scalar, s);
copies.push_back(std::move(zero));
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
// Finish with the gradient for b in case we had a b
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -4,12 +4,14 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/common/load.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -197,7 +199,27 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
static Stream io_stream = new_stream(Device::cpu);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
offset = offset_,
reader = reader_,
swap_endianness = swap_endianness_]() mutable {
load(out, offset, reader, swap_endianness);
};
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
auto signal_task = [out = out, fut = std::move(fut)]() {
fut.wait();
out.event().signal();
};
scheduler::enqueue(io_stream, std::move(signal_task));
auto& d = metal::device(stream().device);
d.end_encoding(stream().index);
auto command_buffer = d.get_command_buffer(stream().index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
}
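The Load::eval_gpu path above enqueues the actual read on the I/O thread pool, signals the output array's event from a CPU stream once the shared future resolves, and has the Metal command buffer encode a wait on that event. A rough standard-library analogue of that wait/signal pattern, with a promise standing in for the MLX event (the stream and event machinery is only approximated here):
#include <future>
#include <thread>
#include <vector>

// A reader task produces the data, a signaling task waits on it and then
// "signals" an event that a consumer (standing in for the GPU command
// buffer) waits on.
int main() {
  std::promise<void> event;               // stands in for out.event()
  auto ready = event.get_future().share();

  auto read_future = std::async(std::launch::async, [] {
    return std::vector<char>(1 << 20, 0); // pretend file read
  }).share();

  std::thread signaler([&, read_future] {
    read_future.wait();                   // fut.wait() in the signal task
    event.set_value();                    // out.event().signal()
  });

  ready.wait();                           // encodeWait on the command buffer
  signaler.join();
  return 0;
}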
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {

File diff suppressed because it is too large

View File

@@ -16,7 +16,8 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s);
const Stream& s,
std::vector<array>& copies);
void row_reduce_general_dispatch(
const array& in,

View File

@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 2;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder->setBytes(&base, sizeof(float), 3);
compute_encoder->setBytes(&scale_, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
compute_encoder->setBytes(out_strides, sizeof(size_t), 4);
uint32_t dim0 = dims_ / 2;
auto group_dims = get_block_dims(dim0, n_batch, 1);
auto grid_dims = MTL::Size(dim0, n_batch, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
} else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}
if (with_freqs) {
auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
} else {
compute_encoder->setBytes(&base, sizeof(float), 10);
}
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast
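Given the kernel-name assembly above and the host_name strings emitted by the instantiation macros, a forward, single-position, traditional RoPE on a float16 input with user-provided freqs resolves to "rope_single_freqs_traditional_float16" (assuming type_to_name maps half inputs to "float16"). A tiny standalone reconstruction:
#include <iostream>
#include <sstream>
#include <string>

int main() {
  bool single = true, with_freqs = true, forward = true, traditional = true;
  std::string tname = "float16"; // assumed type_to_name result for half
  std::ostringstream kname;
  kname << "rope_" << (single ? "single_" : "") << (with_freqs ? "freqs_" : "")
        << (forward ? "" : "vjp_") << (traditional ? "traditional_" : "")
        << tname;
  std::cout << kname.str() << std::endl; // rope_single_freqs_traditional_float16
  return 0;
}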

View File

@@ -56,9 +56,12 @@ void ternary_op_gpu_inplace(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
bool donate_c = c.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_input_array(donate_c ? out : c, 2);
compute_encoder.set_output_array(out, 3);
if (topt == TernaryOpType::General) {
@@ -91,9 +94,10 @@ void ternary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;

View File

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \
@@ -119,6 +120,14 @@ NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -50,17 +50,4 @@ struct Group {
*/
Group init(bool strict = false);
namespace detail {
/* Return the communication stream. */
Stream communication_stream();
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);
/* Perform an all reduce sum operation */
void all_gather(Group group, const array& input, array& output);
} // namespace detail
} // namespace mlx::core::distributed

View File

@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::detail {
/* Return the communication stream. */
Stream communication_stream();
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);
/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);
/** Send an array to the dst rank */
void send(Group group, const array& input, int dst);
/** Recv an array from the src rank */
void recv(Group group, array& out, int src);
} // namespace mlx::core::distributed::detail

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/scheduler.h"
#define LOAD_SYMBOL(symbol, variable) \
@@ -47,6 +48,8 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Comm_free, comm_free);
LOAD_SYMBOL(MPI_Allreduce, all_reduce);
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
// Objects
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -141,6 +144,8 @@ struct MPIWrapper {
MPI_Comm);
int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
int (*comm_free)(MPI_Comm*);
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
// Objects
MPI_Comm comm_world_;
@@ -284,6 +289,29 @@ void all_gather(Group group, const array& input_, array& output) {
to_comm(group));
}
void send(Group group, const array& input_, int dst) {
array input = ensure_row_contiguous(input_);
mpi().send(
input.data<void>(),
input.size(),
mpi().datatype(input),
dst,
0,
to_comm(group));
}
void recv(Group group, array& out, int src) {
MPI_Status status;
mpi().recv(
out.data<void>(),
out.size(),
mpi().datatype(out),
src,
MPI_ANY_TAG,
to_comm(group),
&status);
}
} // namespace detail
} // namespace mlx::core::distributed

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
namespace mlx::core::distributed {
@@ -33,6 +34,8 @@ Stream communication_stream() {
void all_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}
void send(Group group, const array& input, int dst) {}
void recv(Group group, array& out, int src) {}
} // namespace detail

View File

@@ -1,5 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <sstream>
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
@@ -17,7 +19,10 @@ Group to_group(std::optional<Group> group) {
} // namespace
array all_sum(const array& x, std::optional<Group> group_) {
array all_sum(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
@@ -27,11 +32,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(group, AllReduce::Sum),
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
{x});
}
array all_gather(const array& x, std::optional<Group> group_) {
array all_gather(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
@@ -47,8 +55,63 @@ array all_gather(const array& x, std::optional<Group> group_) {
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(group),
std::make_shared<AllGather>(to_stream(s), group),
{x});
}
array send(
const array& x,
int dst,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group");
}
if (dst < 0 || dst >= group.size()) {
std::ostringstream msg;
msg << "Invalid destination=" << dst << " for a group of size "
<< group.size();
throw std::invalid_argument(msg.str());
}
return array(
{0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
}
array recv(
std::vector<int> shape,
Dtype dtype,
int src,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group");
}
if (src < 0 || src >= group.size()) {
std::ostringstream msg;
msg << "Invalid source=" << src << " for a group of size " << group.size();
throw std::invalid_argument(msg.str());
}
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(to_stream(s), group, src),
std::vector<array>{});
}
array recv_like(
const array& x,
int src,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
return recv(x.shape(), x.dtype(), src, group_, s);
}
} // namespace mlx::core::distributed

View File

@@ -5,10 +5,37 @@
#include <optional>
#include "mlx/distributed/distributed.h"
#include "mlx/utils.h"
namespace mlx::core::distributed {
array all_sum(const array& x, std::optional<Group> group = std::nullopt);
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
array all_sum(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_gather(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice S = {});
array send(
const array& x,
int dst,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array recv(
std::vector<int> shape,
Dtype dtype,
int src,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array recv_like(
const array& x,
int src,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed
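A hedged end-to-end sketch of how these declarations compose across two ranks: rank 0 sends a tensor to rank 1, which receives it with recv_like, and both ranks then all-reduce. Group::rank() is assumed to exist (only size() appears in the code shown here), and the arrays returned by send/recv must be evaluated for the communication to actually run:
#include "mlx/mlx.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"

using namespace mlx::core;

int main() {
  auto group = distributed::init(); // strict = false: also works without MPI
  auto x = ones({4, 4});

  if (group.size() > 1) {
    if (group.rank() == 0) { // rank() assumed available on Group
      // The {0}-shaped int32 result must be evaluated to trigger the send.
      eval({distributed::send(x, /* dst = */ 1)});
    } else if (group.rank() == 1) {
      x = distributed::recv_like(x, /* src = */ 0);
    }
  }

  // Every rank participates in the sum of the (possibly received) tensor.
  auto total = distributed::all_sum(x);
  eval({total});
  return 0;
}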

View File

@@ -3,7 +3,6 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"
@@ -36,7 +35,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
const std::vector<int>& axes) {
switch (reduce_type_) {
case Sum:
return {{all_sum(inputs[0], group())}, axes};
return {{all_sum(inputs[0], group(), stream())}, axes};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
@@ -48,7 +47,7 @@ std::vector<array> AllReduce::jvp(
const std::vector<int>& argnums) {
switch (reduce_type_) {
case Sum:
return {all_sum(tangents[0], group())};
return {all_sum(tangents[0], group(), stream())};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
@@ -76,14 +75,14 @@ void AllGather::eval_cpu(
std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{all_gather(inputs[0], group())}, axes};
return {{all_gather(inputs[0], group(), stream())}, axes};
}
std::vector<array> AllGather::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
return {all_gather(tangents[0], group())};
return {all_gather(tangents[0], group(), stream())};
}
std::vector<array> AllGather::vjp(
@@ -99,4 +98,29 @@ std::vector<array> AllGather::vjp(
return {slice(cotangents[0], starts, stops)};
}
void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
distributed::detail::send(group(), inputs[0], dst_);
}
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{send(inputs[0], dst_, group(), stream())}, axes};
}
void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_);
}
} // namespace mlx::core::distributed

View File

@@ -3,20 +3,15 @@
#pragma once
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/primitives.h"
namespace mlx::core::distributed {
class DistPrimitive : public Primitive {
public:
DistPrimitive(Group group)
: Primitive(detail::communication_stream()), group_(group) {}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error(
"Communication primitives cannot be run on the GPU");
}
DistPrimitive(Stream stream, Group group)
: Primitive(stream), group_(group) {}
const Group& group() const {
return group_;
@@ -30,11 +25,13 @@ class AllReduce : public DistPrimitive {
public:
enum ReduceType { And, Or, Sum, Prod, Min, Max };
AllReduce(Group group, ReduceType reduce_type)
: DistPrimitive(group), reduce_type_(reduce_type) {}
AllReduce(Stream stream, Group group, ReduceType reduce_type)
: DistPrimitive(stream, group), reduce_type_(reduce_type) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -77,10 +74,13 @@ class AllReduce : public DistPrimitive {
class AllGather : public DistPrimitive {
public:
AllGather(Group group) : DistPrimitive(group) {}
AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
DEFINE_PRINT(AllGather);
};
class Send : public DistPrimitive {
public:
Send(Stream stream, Group group, int dst)
: DistPrimitive(stream, group), dst_(dst) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_PRINT(Send);
private:
int dst_;
};
class Recv : public DistPrimitive {
public:
Recv(Stream stream, Group group, int src)
: DistPrimitive(stream, group), src_(src) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(Recv);
private:
int src_;
};
} // namespace mlx::core::distributed

View File

@@ -1,8 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <iostream>
#include <numeric>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@@ -323,7 +325,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
}
array rope(
const array& x,
std::vector<array> inputs,
int dims,
bool traditional,
float base,
@@ -331,15 +333,23 @@ array rope(
int offset,
bool forward,
StreamOrDevice s) {
auto& x = inputs[0];
if (x.ndim() < 3) {
std::ostringstream msg;
msg << "[rope] Input must have at least 3 dimensions but got input with "
<< x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (inputs.size() == 2 &&
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
std::ostringstream msg;
msg << "[rope] freqs must be one dimensional with size " << dims / 2
<< " but got shape " << inputs[1].shape() << ".";
throw std::invalid_argument(msg.str());
}
auto fallback = [dims, traditional, base, scale, offset, forward, s](
const std::vector<array>& inputs) {
std::vector<array> inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
@@ -348,10 +358,20 @@ array rope(
// Compute sines and cosines
auto half_dims = dims / 2;
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
auto freqs = negative(arange(0, half_dims, t, s), s);
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp(
multiply(
arange(0, -half_dims, -1, t, s),
array(std::log(base) / half_dims, t),
s),
s);
};
auto inv_freqs =
inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);
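Stated as a formula, with p indexing the (offset) position and i the first dims/2 feature pairs, the fallback rotates each pair by the angle
\theta_{p,i} = \mathrm{scale}\,(p + \mathrm{offset})\, f_i,
\qquad
f_i = \begin{cases} \mathrm{base}^{-2i/\mathrm{dims}} & \text{by default,} \\ 1/\mathrm{freqs}[i] & \text{when freqs is provided,} \end{cases}
so a user-supplied freqs array simply replaces the geometric schedule derived from base.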
@@ -409,20 +429,39 @@ array rope(
x.dtype(),
std::make_shared<RoPE>(
stream, fallback, dims, traditional, base, scale, offset, forward),
{x});
std::move(inputs));
}
return fallback({x})[0];
return fallback(std::move(inputs))[0];
}
array rope(
const array& x,
int dims,
bool traditional,
float base,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
return rope(x, dims, traditional, base, scale, offset, true, s);
std::vector<array> inputs = {x};
if (freqs) {
inputs.push_back(astype(*freqs, float32, s));
if (base) {
throw std::invalid_argument(
"[rope] Only one of base or freqs can have a value.");
}
} else if (!base) {
throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
}
return rope(
std::move(inputs),
dims,
traditional,
base.has_value() ? *base : 1.0,
scale,
offset,
true,
s);
}
std::vector<array> RoPE::vjp(
@@ -438,16 +477,27 @@ std::vector<array> RoPE::vjp(
offset = offset_,
forward = forward_,
s](std::vector<array> inputs) {
return std::vector<array>{
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
return std::vector<array>{rope(
std::move(inputs),
dims,
traditional,
base,
scale,
offset,
!forward,
s)};
};
auto inputs = cotangents;
if (primals.size() == 2) {
inputs.push_back(primals[1]);
}
return {array(
cotangents[0].shape(),
cotangents[0].dtype(),
std::make_shared<RoPE>(
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
cotangents)};
std::move(inputs))};
}
bool RoPE::is_equivalent(const Primitive& other) const {
@@ -465,6 +515,7 @@ array scaled_dot_product_attention(
const array& values,
const float scale,
const std::optional<array>& mask,
const std::optional<int>& memory_efficient_threshold,
StreamOrDevice s) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
@@ -535,6 +586,11 @@ array scaled_dot_product_attention(
* * dtype is not fp32 or fp16
*/
int threshold = 1e6;
if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value());
}
bool needs_mask = mask.has_value();
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
const std::vector<array>& inputs) {
@@ -581,9 +637,10 @@ array scaled_dot_product_attention(
bool implementation_supports_use_case =
supports_sdpa || supports_full_self_attention;
// disabling full self attention until perf is tuned;
// likewise for sdpa
implementation_supports_use_case &= false;
// sdpa gpu shader is disabled except for memory efficient opt-in
const int seq_for_threshold = queries.shape(2);
bool use_memory_efficient_impl = seq_for_threshold >= threshold;
implementation_supports_use_case &= use_memory_efficient_impl;
if (implementation_supports_use_case) {
auto out_shape =
@@ -859,4 +916,285 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
void validate_output_shapes(
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes) {
// Make sure output shapes and dtypes have the same keys
bool validated = true;
if (output_shapes.size() == 0) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
if (output_shapes.size() != output_dtypes.size()) {
validated = false;
} else {
for (const auto& kv : output_shapes) {
if (output_dtypes.find(kv.first) == output_dtypes.end()) {
validated = false;
break;
}
}
}
if (!validated) {
throw std::invalid_argument(
"[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
}
}
void write_signature(
std::string func_name,
std::string& source,
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::optional<std::map<std::string, TemplateArg>> template_args,
std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs,
std::ostringstream& kernel_source) {
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (template_args && template_args.value().size() > 0) {
kernel_source << "template <";
int i = 0;
for (const auto& [name, arg] : template_args.value()) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source << ", ";
}
kernel_source << param_type << " " << name;
i++;
}
kernel_source << ">" << std::endl;
}
kernel_source << "[[kernel]] void " << func_name << "(" << std::endl;
// Metal attributes are automatically added to the arguments if present
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"thread_per_threadgroup", "uint3"},
};
std::vector<std::pair<std::string, std::string>> attrs;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attrs.push_back({attr, dtype});
}
}
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (const auto& [name, arr] : inputs) {
auto dtype = get_type_string(arr.dtype());
bool is_constant =
arr.is_available() && arr.size() < max_constant_array_size;
std::string location = is_constant ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source << " const " << location << " " << dtype << ref << " "
<< name << " [[buffer(" << index << ")]]," << std::endl;
index++;
// Add input shape, strides and ndim if present in the source
CustomKernelShapeInfo shape_info;
if (arr.ndim() > 0) {
if (source.find(name + "_shape") != std::string::npos) {
kernel_source << " const constant int* " << name << "_shape [[buffer("
<< index << ")]]," << std::endl;
shape_info.shape = true;
index++;
}
if (source.find(name + "_strides") != std::string::npos) {
kernel_source << " const constant size_t* " << name
<< "_strides [[buffer(" << index << ")]]," << std::endl;
shape_info.strides = true;
index++;
}
if (source.find(name + "_ndim") != std::string::npos) {
kernel_source << " const constant int& " << name << "_ndim [[buffer("
<< index << ")]]," << std::endl;
shape_info.ndim = true;
index++;
}
}
shape_infos.push_back(shape_info);
}
// Add outputs
for (const auto& [name, dtype] : output_dtypes) {
kernel_source << " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source << "atomic<" << type_string << ">";
} else {
kernel_source << type_string;
}
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
kernel_source << "," << std::endl;
} else {
kernel_source << ") {" << std::endl;
}
index++;
}
// Add metal attributes e.g. `threadgroup_index_in_grid`
index = 0;
for (const auto& [attr, dtype] : attrs) {
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
if (index < attrs.size() - 1) {
kernel_source << "," << std::endl;
} else {
kernel_source << ") {" << std::endl;
}
index++;
}
kernel_source << source << std::endl;
kernel_source << "}" << std::endl;
}
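To make the generated signature concrete, suppose a hypothetical kernel named "my_exp" with a single non-constant float input `a`, a non-atomic float output `out`, and a body that only references `thread_position_in_grid`; write_signature would then emit Metal source along these lines (illustrative, not captured from real output):
// Roughly what write_signature produces for this hypothetical kernel.
[[kernel]] void custom_kernel_my_exp(
    const device float* a [[buffer(0)]],
    device float* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]) {
  // user-provided body follows
  uint elem = thread_position_in_grid.x;
  out[elem] = metal::exp(a[elem]);
}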
std::string write_template(std::map<std::string, TemplateArg>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
std::map<std::string, array> MetalKernel::operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s_) {
validate_output_shapes(output_shapes, output_dtypes);
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument(
"[metal_kernel] MetalKernel only works on GPU.");
}
std::ostringstream func_name;
std::string template_def = "";
bool needs_template = template_args && template_args.value().size() > 0;
std::string hash_key = "";
if (needs_template) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args.value());
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
func_name << "custom_kernel_" << name_ << hash_key;
std::string kernel_name = func_name.str();
std::ostringstream kernel_source;
kernel_source << header_ << std::endl;
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),
source_,
inputs,
output_shapes,
output_dtypes,
template_args,
shape_infos,
atomic_outputs_,
kernel_source);
if (needs_template) {
template_def = func_name.str() + template_def;
kernel_source << std::endl
<< "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
}
if (verbose) {
std::cout << "Generated source code for `" << name_ << "`:" << std::endl
<< "```" << std::endl
<< kernel_source.str() << std::endl
<< "```" << std::endl;
}
std::vector<array> in_arrs;
for (const auto& kv : inputs) {
in_arrs.push_back(kv.second);
}
std::vector<std::string> out_keys;
std::vector<std::vector<int>> out_shapes;
for (const auto& [name, shape] : output_shapes) {
out_keys.push_back(name);
out_shapes.push_back(shape);
}
std::vector<Dtype> out_dtypes;
for (const auto& kv : output_dtypes) {
out_dtypes.push_back(kv.second);
}
std::map<std::string, array> outputs;
auto outputs_vec = array::make_arrays(
out_shapes,
out_dtypes,
std::make_shared<CustomKernel>(
s,
kernel_name,
kernel_source.str(),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous_,
init_value),
in_arrs);
int i = 0;
for (const auto& key : out_keys) {
outputs.insert({key, outputs_vec[i]});
i++;
}
return outputs;
}
} // namespace mlx::core::fast

View File

@@ -2,6 +2,7 @@
#pragma once
#include <map>
#include <optional>
#include "mlx/utils.h"
@@ -25,9 +26,10 @@ array rope(
const array& x,
int dims,
bool traditional,
float base,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs = std::nullopt,
StreamOrDevice s = {});
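A brief usage sketch of the updated signature; the implementation accepts exactly one of `base` or `freqs`, and `freqs` must be one dimensional with `dims / 2` entries (the shapes below are hypothetical):
#include <optional>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = zeros({1, 8, 64}); // rope requires at least 3 dimensions

  // Geometric schedule derived from base; freqs stays std::nullopt.
  auto y1 = fast::rope(
      x, /* dims */ 64, /* traditional */ false, /* base */ 10000.0f,
      /* scale */ 1.0f, /* offset */ 0);

  // Explicit per-pair frequencies; base must then be std::nullopt.
  auto freqs = ones({32}); // dims / 2 entries
  auto y2 = fast::rope(x, 64, false, std::nullopt, 1.0f, 0, freqs);

  eval({y1, y2});
  return 0;
}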
/** Computes: O = softmax(Q @ K.T) @ V **/
@@ -37,6 +39,7 @@ array scaled_dot_product_attention(
const array& values,
const float scale,
const std::optional<array>& mask = std::nullopt,
const std::optional<int>& memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
@@ -61,4 +64,39 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
class MetalKernel {
public:
MetalKernel(
const std::string& name,
const std::string& source,
const std::string& header = "",
bool ensure_row_contiguous = true,
bool atomic_outputs = false)
: name_(name),
source_(source),
header_(header),
ensure_row_contiguous_(ensure_row_contiguous),
atomic_outputs_(atomic_outputs) {}
std::map<std::string, array> operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args =
std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s = {});
private:
std::string name_;
std::string source_;
std::string header_;
bool ensure_row_contiguous_;
bool atomic_outputs_;
};
} // namespace mlx::core::fast
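A hedged sketch of driving the MetalKernel class declared above directly from C++ (it requires a GPU stream; the kernel name and body here are invented for illustration, and the [[kernel]] signature around the body is generated by write_signature):
#include <map>
#include <string>
#include <tuple>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto a = ones({16});

  // Body only; the signature is generated around it.
  std::string source = R"(
    uint elem = thread_position_in_grid.x;
    out[elem] = 2.0f * a[elem];
  )";

  fast::MetalKernel kernel("twice", source);

  std::map<std::string, array> inputs = {{"a", a}};
  auto outputs = kernel(
      inputs,
      /* output_shapes */ {{"out", {16}}},
      /* output_dtypes */ {{"out", float32}},
      /* grid */ std::make_tuple(16, 1, 1),
      /* threadgroup */ std::make_tuple(16, 1, 1));

  eval({outputs.at("out")});
  return 0;
}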

View File

@@ -1,5 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <optional>
#include "mlx/primitives.h"
namespace mlx::core::fast {
@@ -242,4 +244,50 @@ class AffineQuantize : public Custom {
bool dequantize_;
};
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;
bool ndim = false;
};
class CustomKernel : public Primitive {
public:
CustomKernel(
Stream stream,
std::string name,
std::string source,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
bool ensure_row_contiguous,
std::optional<float> init_value)
: Primitive(stream),
source_(source),
name_(name),
grid_(grid),
threadgroup_(threadgroup),
shape_infos_(shape_infos),
ensure_row_contiguous_(ensure_row_contiguous),
init_value_(init_value) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("Custom Metal kernels only run on GPU.");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(CustomKernel);
private:
std::string source_;
std::string name_;
std::tuple<int, int, int> grid_;
std::tuple<int, int, int> threadgroup_;
std::vector<CustomKernelShapeInfo> shape_infos_;
bool ensure_row_contiguous_;
std::optional<float> init_value_;
};
} // namespace mlx::core::fast

View File

@@ -2,6 +2,7 @@
#include <cstdint>
#include <cstring>
#include <fstream>
#include <numeric>
#include "mlx/io/gguf.h"

View File

@@ -298,7 +298,65 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
/** Load array from file in .npy format */
array load(std::string file, StreamOrDevice s) {
return load(std::make_shared<io::FileReader>(std::move(file)), s);
return load(std::make_shared<io::ParallelFileReader>(std::move(file)), s);
}
namespace io {
ThreadPool& thread_pool() {
static ThreadPool pool_{4};
return pool_;
}
ThreadPool ParallelFileReader::thread_pool_{4};
void ParallelFileReader::read(char* data, size_t n) {
while (n != 0) {
auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
if (m <= 0) {
std::ostringstream msg;
msg << "[read] Unable to read " << n << " bytes from file.";
throw std::runtime_error(msg.str());
}
data += m;
n -= m;
}
}
void ParallelFileReader::read(char* data, size_t n, size_t offset) {
auto readfn = [fd = fd_](size_t offset, size_t size, char* buffer) -> bool {
while (size != 0) {
auto m = pread(fd, buffer, size, offset);
if (m <= 0) {
return false;
}
buffer += m;
size -= m;
}
return true;
};
std::vector<std::future<bool>> futs;
while (n != 0) {
if (n < batch_size_) {
if (!readfn(offset, n, data)) {
throw std::runtime_error("[read] Unable to read from file.");
}
break;
} else {
size_t m = batch_size_;
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
data += m;
n -= m;
offset += m;
}
}
for (auto& f : futs) {
if (!f.get()) {
throw std::runtime_error("[read] Unable to read from file.");
}
}
}
} // namespace io
} // namespace mlx::core
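A minimal usage sketch of the new reader (the header path, file name, offset, and size below are assumptions for illustration). The offset-taking overload is the thread-pool-backed one; the plain read advances the descriptor and must not be shared across threads:
#include <cstddef>
#include <vector>

#include "mlx/io/load.h" // assumed header location for the Reader classes

int main() {
  mlx::core::io::ParallelFileReader reader("model.safetensors"); // hypothetical file
  if (!reader.is_open()) {
    return 1;
  }

  const size_t offset = 8;       // e.g. just past a small header
  const size_t nbytes = 1 << 26; // 64 MiB payload -> two 32 MiB batches on the pool
  std::vector<char> buffer(nbytes);

  reader.read(buffer.data(), nbytes, offset); // batched, concurrent pread
  return 0;
}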

View File

@@ -2,14 +2,20 @@
#pragma once
#include <fstream>
#include <istream>
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <memory>
#include <sstream>
#include "mlx/io/threadpool.h"
namespace mlx::core {
namespace io {
ThreadPool& thread_pool();
class Reader {
public:
virtual bool is_open() const = 0;
@@ -19,7 +25,9 @@ class Reader {
int64_t off,
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void read(char* data, size_t n) = 0;
virtual void read(char* data, size_t n, size_t offset) = 0;
virtual std::string label() const = 0;
virtual ~Reader() = default;
};
class Writer {
@@ -32,73 +40,93 @@ class Writer {
std::ios_base::seekdir way = std::ios_base::beg) = 0;
virtual void write(const char* data, size_t n) = 0;
virtual std::string label() const = 0;
virtual ~Writer() = default;
};
class FileReader : public Reader {
class ParallelFileReader : public Reader {
public:
explicit FileReader(std::ifstream is)
: is_(std::move(is)), label_("stream") {}
explicit FileReader(std::string file_path)
: is_(std::ifstream(file_path, std::ios::binary)),
label_(std::move(file_path)) {}
explicit ParallelFileReader(std::string file_path)
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
~ParallelFileReader() override {
close(fd_);
}
bool is_open() const override {
return is_.is_open();
return fd_ > 0;
}
bool good() const override {
return is_.good();
return is_open();
}
size_t tell() override {
return is_.tellg();
return lseek(fd_, 0, SEEK_CUR);
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
is_.seekg(off, way);
void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override {
throw std::runtime_error("[ParallelFileReader::seek] Not allowed");
}
void read(char* data, size_t n) override {
is_.read(data, n);
}
// Warning: do not use this function from multiple threads as
// it advances the file descriptor
void read(char* data, size_t n) override;
void read(char* data, size_t n, size_t offset) override;
std::string label() const override {
return "file " + label_;
}
private:
std::ifstream is_;
static constexpr size_t batch_size_ = 1 << 25;
static ThreadPool thread_pool_;
int fd_;
std::string label_;
};
class FileWriter : public Writer {
public:
explicit FileWriter(std::ofstream os)
: os_(std::move(os)), label_("stream") {}
explicit FileWriter(std::string file_path)
: os_(std::ofstream(file_path, std::ios::binary)),
: fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
label_(std::move(file_path)) {}
~FileWriter() override {
close(fd_);
}
bool is_open() const override {
return os_.is_open();
return fd_ >= 0;
}
bool good() const override {
return os_.good();
return is_open();
}
size_t tell() override {
return os_.tellp();
return lseek(fd_, 0, SEEK_CUR);
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
os_.seekp(off, way);
if (way == std::ios_base::beg) {
lseek(fd_, off, 0);
} else {
lseek(fd_, off, SEEK_CUR);
}
}
void write(const char* data, size_t n) override {
os_.write(data, n);
while (n != 0) {
auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
if (m <= 0) {
std::ostringstream msg;
msg << "[write] Unable to write " << n << " bytes to file.";
throw std::runtime_error(msg.str());
}
data += m;
n -= m;
}
}
std::string label() const override {
@@ -106,7 +134,7 @@ class FileWriter : public Writer {
}
private:
std::ofstream os_;
int fd_;
std::string label_;
};

View File

@@ -147,7 +147,7 @@ SafetensorsLoad load_safetensors(
}
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s);
return load_safetensors(std::make_shared<io::ParallelFileReader>(file), s);
}
void save_safetensors(

mlx/io/threadpool.h (new file, 103 lines)
View File

@@ -0,0 +1,103 @@
// This code was modified from https://github.com/progschj/ThreadPool
// The original License is copied below:
//
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
//
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
//
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.
//
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
//
// 3. This notice may not be removed or altered from any source
// distribution.
#pragma once
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool {
public:
ThreadPool(size_t);
template <class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of_t<F(Args...)>>;
~ThreadPool();
private:
std::vector<std::thread> workers;
std::queue<std::function<void()>> tasks;
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i)
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
}
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of_t<F(Args...)>> {
using return_type = typename std::result_of_t<F(Args...)>;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
if (stop) {
throw std::runtime_error(
"[ThreadPool::enqueue] Not allowed on stopped ThreadPool");
}
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
inline ThreadPool::~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread& worker : workers)
worker.join();
}

View File

@@ -306,6 +306,49 @@ array cholesky(
{a});
}
array pinv(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::pinv] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
<< "with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
int m = a.shape(-2);
int n = a.shape(-1);
int k = std::min(m, n);
auto outs = linalg::svd(a, s);
array U = outs[0];
array S = outs[1];
array V = outs[2];
std::vector<int> starts(a.ndim(), 0);
std::vector<int> ends = a.shape();
int i = a.ndim() - 2;
int j = a.ndim() - 1;
// Prepare U
ends[i] = m;
ends[j] = k;
U = swapaxes(slice(U, starts, ends, s), -1, -2, s);
// Prepare V
ends[i] = k;
ends[j] = n;
V = swapaxes(slice(V, starts, ends, s), -1, -2, s);
// Prepare S
S = expand_dims(S, -2, s);
return matmul(divide(V, S, s), U);
}
array cholesky_inv(
const array& L,
bool upper /* = false */,

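The implementation above is the standard SVD route: with the economy-size factors, A⁺ = V Σ⁻¹ Uᵀ. A hedged Python check against the mx.linalg.pinv binding added later in this diff (the CPU stream is an assumption, since the SVD backend runs on the CPU):
import mlx.core as mx

a = mx.random.normal((3, 5))
aplus = mx.linalg.pinv(a, stream=mx.cpu)  # assumption: SVD-backed ops need a CPU stream
# Moore-Penrose property: a @ aplus @ a recovers a up to float error.
print(mx.allclose(a @ aplus @ a, a, atol=1e-4))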
View File

@@ -70,6 +70,8 @@ array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {});
array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});
array pinv(const array& a, StreamOrDevice s = {});
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
} // namespace mlx::core::linalg

View File

@@ -16,9 +16,11 @@ namespace mlx::core {
namespace {
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
compute_reduce_shape(
const std::vector<int>& axes,
const std::vector<int>& shape) {
bool is_noop = true;
std::set<int> axes_set;
auto ndim = shape.size();
for (auto ax : axes) {
@@ -35,15 +37,18 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
throw std::invalid_argument("Duplicate axes detected in reduction.");
}
std::vector<int> out_shape;
std::vector<int> squeezed_shape;
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
squeezed_shape.push_back(shape[i]);
} else {
out_shape.push_back(1);
}
is_noop &= (out_shape.back() == shape[i]);
}
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
return {out_shape, sorted_axes};
return {out_shape, sorted_axes, squeezed_shape, is_noop};
}
Dtype at_least_float(const Dtype& d) {
@@ -1378,6 +1383,10 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
return logical_or(isposinf(a, s), isneginf(a, s), s);
}
array isfinite(const array& a, StreamOrDevice s /* = {} */) {
return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
@@ -1498,17 +1507,17 @@ array all(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
: array(
std::move(out_shape),
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1532,17 +1541,17 @@ array any(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
: array(
std::move(out_shape),
bool_,
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1569,15 +1578,18 @@ array sum(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
auto out = array(
out_shape,
out_type,
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
auto out = (is_noop)
? astype(a, out_type, s)
: array(
std::move(out_shape),
out_type,
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1711,14 +1723,17 @@ array prod(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1745,17 +1760,17 @@ array max(
if (a.size() == 0) {
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
}
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1785,14 +1800,17 @@ array min(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
: array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1825,15 +1843,18 @@ array argmin(
throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
: array(
std::move(out_shape),
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -1858,15 +1879,18 @@ array argmax(
throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
: array(
std::move(out_shape),
uint32,
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
out = reshape(out, std::move(squeezed_shape), s);
}
return out;
}
@@ -2903,6 +2927,10 @@ array softmax(
const std::vector<int>& axes,
bool precise /* = false */,
StreamOrDevice s /* = {}*/) {
if (a.size() == 0) {
return a;
}
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
return array(

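The new is_noop path lets reductions over axes that already have size 1 skip the Reduce primitive and fall back to a cast/reshape; shapes, dtypes, and values are unchanged. A hedged illustration:
import mlx.core as mx

x = mx.ones((4, 1, 3))
# Axis 1 has size 1, so the reduction is a no-op and no Reduce kernel is built.
print(mx.sum(x, axis=1).shape)                 # 4 x 3
print(mx.sum(x, axis=1, keepdims=True).shape)  # 4 x 1 x 3
print(mx.all(x > 0, axis=1).dtype)             # bool, as before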
View File

@@ -399,6 +399,8 @@ array isnan(const array& a, StreamOrDevice s = {});
array isinf(const array& a, StreamOrDevice s = {});
array isfinite(const array& a, StreamOrDevice s = {});
array isposinf(const array& a, StreamOrDevice s = {});
array isneginf(const array& a, StreamOrDevice s = {});

View File

@@ -186,10 +186,35 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
void async_eval(std::vector<array> outputs) {
if (outputs.empty()) {
return;
}
if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
return x.status() == array::Status::unscheduled;
})) {
return;
}
eval_impl(std::move(outputs), true);
}
void eval(std::vector<array> outputs) {
if (outputs.empty()) {
return;
}
if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
return x.status() == array::Status::unscheduled;
})) {
for (auto& x : outputs) {
if (!x.is_available()) {
x.event().wait();
}
}
return;
}
eval_impl(std::move(outputs), false).event().wait();
}
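With the early returns above, eval on already-scheduled arrays reduces to waiting on their events, which pairs naturally with async_eval. A hedged sketch:
import mlx.core as mx

x = mx.random.normal((1024, 1024))
y = (x @ x).sum()

mx.async_eval(y)  # schedules the graph and returns immediately
# ... overlap other host-side work here ...
mx.eval(y)        # y is already scheduled, so this call just waits on its event
print(y.item())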

View File

@@ -4,10 +4,10 @@
#include <variant>
#include "array.h"
#include "device.h"
#include "dtype.h"
#include "stream.h"
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/stream.h"
namespace mlx::core {

View File

@@ -1,7 +1,7 @@
[build-system]
requires = [
"setuptools>=42",
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
"nanobind==2.1.0",
"cmake>=3.24",
]
build-backend = "setuptools.build_meta"

View File

@@ -234,7 +234,7 @@ def glorot_uniform(
def he_normal(
dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
r"""Build a He normal initializer.
This initializer samples from a normal distribution with a standard
@@ -292,7 +292,7 @@ def he_normal(
def he_uniform(
dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
r"""A He uniform (Kaiming uniform) initializer.
This initializer samples from a uniform distribution with a range

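The Literal annotations document that the second argument of the returned initializer is restricted to the two fan modes. A hedged usage sketch (argument order per the mlx.nn.init docs):
import mlx.core as mx
import mlx.nn as nn

init_fn = nn.init.he_uniform()
w = init_fn(mx.zeros((64, 128)), "fan_out")  # mode must be "fan_in" or "fan_out"
print(w.shape)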
View File

@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
from mlx.nn.layers.transformer import (
MultiHeadAttention,
Transformer,
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)

View File

@@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
from __future__ import annotations
import textwrap
from typing import Any, Callable, List, Optional, Tuple, Union
@@ -7,42 +9,6 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)
elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}
elif isinstance(value, dict):
nd = {}
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd
elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
class Module(dict):
"""Base class for building neural networks with MLX.
@@ -151,7 +117,7 @@ class Module(dict):
self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True,
) -> "Module":
) -> Module:
"""
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@@ -266,9 +232,9 @@ class Module(dict):
def filter_and_map(
self,
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
filter_fn: Callable[[Module, str, Any], bool],
map_fn: Optional[Callable] = None,
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
):
"""Recursively filter the contents of the module using ``filter_fn``,
namely only select keys and values where ``filter_fn`` returns true.
@@ -323,7 +289,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> "Module":
def update(self, parameters: dict) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
@@ -371,8 +337,8 @@ class Module(dict):
def apply(
self,
map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
) -> "Module":
filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
) -> Module:
"""Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters.
@@ -391,7 +357,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict) -> "Module":
def update_modules(self, modules: dict) -> Module:
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
@@ -432,9 +398,7 @@ class Module(dict):
apply(self, modules)
return self
def apply_to_modules(
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
) -> "Module":
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
"""Apply a function to all the modules in this instance (including this
instance).
@@ -489,7 +453,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
) -> "Module":
) -> Module:
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
@@ -544,7 +508,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
) -> "Module":
) -> Module:
"""Unfreeze the Module's parameters or some of them.
This function is idempotent ie unfreezing a model that is not frozen is
@@ -588,7 +552,7 @@ class Module(dict):
_unfreeze_impl("", self)
return self
def train(self, mode: bool = True) -> "Module":
def train(self, mode: bool = True) -> Module:
"""Set the model in or out of training mode.
Training mode only applies to certain layers. For example
@@ -608,7 +572,7 @@ class Module(dict):
self.apply_to_modules(_set_train)
return self
def eval(self) -> "Module":
def eval(self) -> Module:
"""Set the model to evaluation mode.
See :func:`train`.
@@ -637,3 +601,39 @@ class Module(dict):
return True
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)
elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}
elif isinstance(value, dict):
nd = {}
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd
elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
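The added from __future__ import annotations is what allows the un-quoted -> Module return annotations throughout the class: every annotation becomes a lazily evaluated string. A minimal illustration (hedged, not MLX code):
from __future__ import annotations


class Node(dict):
    # Without the future import this would need the string form "Node", since
    # the class is not yet bound when the method's annotation is evaluated.
    def clone(self) -> Node:
        return Node(self)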

View File

@@ -32,7 +32,7 @@ class Dropout(Module):
mask = mx.random.bernoulli(self._p_1, x.shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)
class Dropout2d(Module):
@@ -85,7 +85,7 @@ class Dropout2d(Module):
mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)
class Dropout3d(Module):
@@ -134,4 +134,4 @@ class Dropout3d(Module):
mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)
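Multiplying the boolean mask by x first keeps the intermediate in x's dtype; in the old ordering the Python-float scale promoted the mask, and therefore the output, to float32 for half-precision inputs. A hedged check:
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((8, 8)).astype(mx.float16)
drop = nn.Dropout(p=0.5)
drop.train()
print(drop(x).dtype)  # expected: float16 with the new ordering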

View File

@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
):
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)

View File

@@ -12,7 +12,7 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
class_predicate: Optional[callable] = None,
class_predicate: Optional[Callable] = None,
):
"""Quantize the sub-modules of a module according to a predicate.

View File

@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
else:
y = self.attention(x, x, x, mask)
y = self.dropout1(y)
y = self.ln1(x + y)
x = self.ln1(x + y)
y = self.linear1(y)
y = self.linear1(x)
y = self.activation(y)
y = self.dropout2(y)
y = self.linear2(y)
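The fix writes the post-norm result back into x so that both the MLP input and the second residual use the normalized attention output; previously the MLP consumed the normalized value while the residual reused the pre-norm one. A hedged standalone sketch of the corrected post-norm flow (not the actual TransformerEncoderLayer):
import mlx.core as mx
import mlx.nn as nn

class PostNormBlock(nn.Module):
    def __init__(self, dims: int, hidden: int, num_heads: int = 4):
        super().__init__()
        self.attention = nn.MultiHeadAttention(dims, num_heads)
        self.ln1 = nn.LayerNorm(dims)
        self.ln2 = nn.LayerNorm(dims)
        self.linear1 = nn.Linear(dims, hidden)
        self.linear2 = nn.Linear(hidden, dims)

    def __call__(self, x, mask=None):
        y = self.attention(x, x, x, mask)
        x = self.ln1(x + y)          # the fix: keep the normalized residual in x
        y = self.linear2(nn.relu(self.linear1(x)))
        return self.ln2(x + y)       # second residual now adds the normalized x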

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Literal
from typing import Literal, Optional
import mlx.core as mx
@@ -22,7 +22,7 @@ def _reduce(loss: mx.array, reduction: Reduction = "none"):
def cross_entropy(
logits: mx.array,
targets: mx.array,
weights: mx.array = None,
weights: Optional[mx.array] = None,
axis: int = -1,
label_smoothing: float = 0.0,
reduction: Reduction = "none",
@@ -117,7 +117,7 @@ def cross_entropy(
def binary_cross_entropy(
inputs: mx.array,
targets: mx.array,
weights: mx.array = None,
weights: Optional[mx.array] = None,
with_logits: bool = True,
reduction: Reduction = "mean",
) -> mx.array:

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from functools import wraps
from typing import Callable
from typing import Callable, Optional
import mlx.core as mx
@@ -37,7 +37,7 @@ def value_and_grad(model: Module, fn: Callable):
return wrapped_value_grad_fn
def checkpoint(module: Module, fn: Callable = None):
def checkpoint(module: Module, fn: Optional[Callable] = None):
"""Transform the passed callable to one that performs gradient
checkpointing with respect to the trainable parameters of the module (and
the callable's inputs).

View File

@@ -4,6 +4,7 @@ import math
from typing import Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.nn import Module
from mlx.utils import tree_map, tree_reduce
@@ -17,7 +18,7 @@ class Optimizer:
self._state = {"step": mx.array(0, mx.uint64)}
self._schedulers = {k: v for k, v in (schedulers or {}).items()}
def update(self, model: "mlx.nn.Module", gradients: dict):
def update(self, model: Module, gradients: dict):
"""Apply the gradients to the parameters of the model and update the
model with the new parameters.
@@ -48,8 +49,28 @@ class Optimizer:
>>> optimizer.state.keys()
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
"""
self._state.update(tree_map(lambda x: {}, parameters))
tree_map(self.init_single, parameters, self._state)
# Initialize the optimizer state to match the parameter state
def update_state(params, state):
if isinstance(params, (list, tuple)):
state = list(state)
for i in range(len(state)):
state[i] = update_state(params[i], state[i])
if len(state) != len(params):
state.extend(tree_map(lambda x: {}, params[len(state) :]))
return type(params)(state)
elif isinstance(params, dict):
for k, v in params.items():
if k not in state:
state[k] = tree_map(lambda x: {}, v)
else:
state[k] = update_state(v, state[k])
return state
else:
return state
update_state(parameters, self._state)
tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
self._initialized = True
def init_single(self, parameter: mx.array, state: dict):
@@ -104,7 +125,7 @@ class Optimizer:
@state.setter
def state(self, state: dict):
self._initialized = True
self._initialized = False
self._state = state
@property

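With the reworked state walk, re-initializing against a parameter tree that gained entries only creates state for the new parameters and leaves existing per-parameter state untouched (previously the whole state was rebuilt). A hedged sketch; the growing-model scenario is illustrative, not from this diff:
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 4)
opt = optim.Adam(learning_rate=1e-3)
opt.init(model.trainable_parameters())  # builds per-parameter Adam state

model.extra = nn.Linear(4, 4)           # the model grows a new sub-module
opt.init(model.trainable_parameters())  # only "extra" gets fresh state; the rest is kept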
View File

@@ -1,10 +1,10 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from typing import Any, Callable, Tuple
from typing import Any, Callable, List, Optional, Tuple
def tree_map(
fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -59,8 +59,8 @@ def tree_map(
def tree_map_with_path(
fn: Callable,
tree: Any,
*rest: Tuple[Any],
is_leaf: Callable = None,
*rest: Any,
is_leaf: Optional[Callable] = None,
path: Any = None,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
@@ -111,7 +111,9 @@ def tree_map_with_path(
return fn(path, tree, *rest)
def tree_flatten(tree, prefix="", is_leaf=None):
def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):
return [(prefix[1:], tree)]
def tree_unflatten(tree):
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python

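For reference, the flatten/unflatten round trip these annotations describe (standard mlx.utils usage, shown as a hedged sketch):
from mlx.utils import tree_flatten, tree_unflatten

tree = {"layers": [{"w": 1}, {"w": 2}]}
flat = tree_flatten(tree)    # [("layers.0.w", 1), ("layers.1.w", 2)]
print(flat)
print(tree_unflatten(flat))  # rebuilds the original nested structure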
View File

@@ -9,6 +9,7 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <nanobind/typing.h>
#include "mlx/backend/metal/metal.h"
#include "python/src/buffer.h"
@@ -113,6 +114,7 @@ void init_array(nb::module_& m) {
.def("__hash__", [](const Dtype& t) {
return static_cast<int64_t>(t.val);
});
m.attr("bool_") = nb::cast(bool_);
m.attr("uint8") = nb::cast(uint8);
m.attr("uint16") = nb::cast(uint16);
@@ -177,7 +179,7 @@ void init_array(nb::module_& m) {
.export_values();
nb::class_<ArrayAt>(
m,
"_ArrayAt",
"ArrayAt",
R"pbdoc(
A helper object to apply updates at specific indices.
)pbdoc")
@@ -195,7 +197,7 @@ void init_array(nb::module_& m) {
nb::class_<ArrayPythonIterator>(
m,
"_ArrayIterator",
"ArrayIterator",
R"pbdoc(
A helper object to iterate over the 1st dimension of an array.
)pbdoc")
@@ -840,6 +842,8 @@ void init_array(nb::module_& m) {
},
"other"_a,
nb::rv_policy::none)
.def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
.def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
.def(
"flatten",
[](const array& a,

View File

@@ -160,6 +160,10 @@ nb::ndarray<> mlx_to_dlpack(const array& a) {
}
nb::object to_scalar(array& a) {
if (a.size() != 1) {
throw std::invalid_argument(
"[convert] Only length-1 arrays can be converted to Python scalars.");
}
{
nb::gil_scoped_release nogil;
a.eval();

View File
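Together with the length-1 guard above, the new __int__/__float__ bindings let scalar arrays be used where Python expects a plain number, while multi-element arrays are rejected. A hedged sketch (nanobind surfaces std::invalid_argument as ValueError):
import mlx.core as mx

x = mx.array(3.5)
print(int(x), float(x))    # 3 3.5

try:
    int(mx.array([1, 2]))  # not length-1: rejected by to_scalar
except ValueError as e:
    print(e)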

@@ -3,6 +3,8 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
@@ -74,8 +76,9 @@ void init_distributed(nb::module_& parent_module) {
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
"def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
All reduce sum.
@@ -86,6 +89,8 @@ void init_distributed(nb::module_& parent_module) {
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The sum of all ``x`` arrays.
@@ -97,8 +102,9 @@ void init_distributed(nb::module_& parent_module) {
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
"def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Gather arrays from all processes.
@@ -110,8 +116,96 @@ void init_distributed(nb::module_& parent_module) {
group (Group): The group of processes that will participate in the
gather. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The concatenation of all ``x`` arrays.
)pbdoc");
m.def(
"send",
&distributed::send,
"x"_a,
"dst"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Send an array from the current process to the process that has rank
``dst`` in the group.
Args:
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
send. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: An empty array; evaluating it performs the send.
)pbdoc");
m.def(
"recv",
&distributed::recv,
"shape"_a,
"dtype"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Receive an array with shape ``shape`` and dtype ``dtype`` from the process
with rank ``src``.
Args:
shape (Tuple[int]): The shape of the array we are receiving.
dtype (Dtype): The data type of the array we are receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
)pbdoc");
m.def(
"recv_like",
&distributed::recv_like,
"x"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Receive an array with the same shape and type as ``x`` from the process with
rank ``src``.
It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
Args:
x (array): An array defining the shape and dtype of the array we are
receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
)pbdoc");
}
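A hedged two-process sketch of the new point-to-point ops; the mpirun launch line and group setup are assumptions, not part of this diff:
# assumed launch: mpirun -np 2 python ping_pong.py
import mlx.core as mx

group = mx.distributed.init()
rank = group.rank()

x = mx.arange(4) + rank
if rank == 0:
    mx.eval(mx.distributed.send(x, dst=1))  # the send happens when this is evaluated
    y = mx.distributed.recv_like(x, src=1)
else:
    y = mx.distributed.recv_like(x, src=0)
    mx.eval(mx.distributed.send(x, dst=0))
mx.eval(y)
print(rank, y)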

View File

@@ -1,9 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "python/src/utils.h"
#include "mlx/fast.h"
#include "mlx/ops.h"
@@ -79,25 +84,29 @@ void init_fast(nb::module_& parent_module) {
"dims"_a,
nb::kw_only(),
"traditional"_a,
"base"_a,
"base"_a.none(),
"scale"_a,
"offset"_a,
"freqs"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Apply rotary positional encoding to the input.
Args:
a (array): Input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
implementation which rotates consecutive dimensions.
base (float): The base used to compute angular frequency for
each dimension in the positional encodings.
implementation which rotates consecutive dimensions.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int): The position offset to start at.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.
Returns:
array: The output array.
@@ -112,9 +121,10 @@ void init_fast(nb::module_& parent_module) {
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
@@ -182,4 +192,153 @@ void init_fast(nb::module_& parent_module) {
Returns:
array: The quantized version of ``w``
)pbdoc");
nb::class_<fast::MetalKernel>(
m,
"metal_kernel",
R"pbdoc(
A jit-compiled custom Metal kernel defined from a source string.
)pbdoc")
.def(
nb::init<
const std::string&,
const std::string&,
const std::string&,
bool,
bool>(),
"name"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc(
Initialize a metal_kernel.
Args:
name (str): Name for the kernel.
source (str): Source code. This is the body of a function in Metal;
the function signature will be generated for you. The names of the inputs/outputs
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
used when the kernel is called.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
Returns:
Callable ``metal_kernel``.
Example:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = '''
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
'''
kernel = mx.fast.metal_kernel(
name="myexp",
source=source
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
verbose=True,
)
return outputs["out"]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc")
.def(
"__call__",
[](fast::MetalKernel& kernel,
std::map<std::string, ScalarOrArray>& inputs_,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, nb::handle>> template_args_,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s) {
std::map<std::string, array> inputs;
for (const auto& [name, value] : inputs_) {
auto arr = to_array(value, std::nullopt);
inputs.insert({name, arr});
}
std::map<std::string, fast::TemplateArg> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.insert({name, bool_val});
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.insert({name, int_val});
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
template_args.insert({name, dtype});
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
The keys will be the names of the arguments to the kernel.
output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
output variable names to shapes. These will be added to the function signature.
output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
names to dtypes. Must have the same keys as ``output_shapes``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
)pbdoc");
}
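After this change exactly one of base and freqs is given: the classic call keeps a numeric base, while custom per-dimension frequencies are passed as freqs with base=None. A hedged sketch of the unchanged base form:
import mlx.core as mx

x = mx.random.normal((2, 8, 64))  # (batch, sequence, feature)
y = mx.fast.rope(x, 64, traditional=False, base=10000.0, scale=1.0, offset=0)
print(y.shape)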

View File

@@ -63,7 +63,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix or vector norm.
@@ -74,7 +74,7 @@ void init_linalg(nb::module_& parent_module) {
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
ord (int, float or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
along the given ``axis``. Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
@@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"),
R"pbdoc(
The QR factorization of the input matrix.
@@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"),
R"pbdoc(
The Singular Value Decomposition (SVD) of the input matrix.
@@ -325,9 +325,9 @@ void init_linalg(nb::module_& parent_module) {
nb::sig(
"def cholesky_inv(L: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition L.
Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition.
Let A be a real symmetric positive semi-definite matrix and L its Cholesky definition such that:
Let :math:`\mathbf{A}` be a real symmetric positive semi-definite matrix and :math:`\mathbf{L}` its Cholesky decomposition such that:
.. math::
@@ -339,7 +339,7 @@ void init_linalg(nb::module_& parent_module) {
This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the Cholesky inverse is computed for each matrix
in the last two dimensions of ``L``.
in the last two dimensions of :math:`\mathbf{L}`.
If the input matrix is not a triangular matrix behaviour is undefined.
@@ -351,6 +351,30 @@ void init_linalg(nb::module_& parent_module) {
in which case the default stream of the default device is used.
Returns:
array: :math:`A^{-1}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
array: :math:`\mathbf{A^{-1}}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
)pbdoc");
m.def(
"pinv",
&pinv,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def pinv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
This function calculates a generalized inverse of a matrix using its
singular-value decomposition. This function supports arrays with at least 2 dimensions.
When the input has more than two dimensions, the inverse is computed for each
matrix in the last two dimensions of ``a``.
Args:
a (array): Input array.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: ``aplus`` such that ``a @ aplus @ a = a``
)pbdoc");
}

View File

@@ -138,6 +138,21 @@ class PyFileReader : public io::Reader {
void read(char* data, size_t n) override {
nb::gil_scoped_acquire gil;
_read(data, n);
}
void read(char* data, size_t n, size_t offset) override {
nb::gil_scoped_acquire gil;
seek_func_(offset, (int)std::ios_base::beg);
_read(data, n);
}
std::string label() const override {
return "python file object";
}
private:
void _read(char* data, size_t n) {
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
nb::object bytes_read = readinto_func_(nb::handle(memview));
@@ -146,11 +161,6 @@ class PyFileReader : public io::Reader {
}
}
std::string label() const override {
return "python file object";
}
private:
nb::object pyistream_;
nb::object readinto_func_;
nb::object seek_func_;

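The new offset overload is what lets the parallel loader also handle arbitrary Python file objects, falling back to seek-then-read under the GIL. A hedged round-trip sketch:
import mlx.core as mx

mx.save("weights.npy", mx.ones((4, 4)))
with open("weights.npy", "rb") as f:
    w = mx.load(f)  # a file object dispatches through PyFileReader
print(w.shape)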
Some files were not shown because too many files have changed in this diff.